1.修改代码适配阿里云的服务器

This commit is contained in:
whm
2026-03-17 14:27:32 +08:00
parent 826617d737
commit 20e7f3a65d
1777 changed files with 775041 additions and 10 deletions

View File

@@ -0,0 +1,208 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"fmt"
"io"
"strconv"
"strings"
)
// NewArrayLengthError creates and returns an error for when the length of an array exceeds the
// bytes available.
func NewArrayLengthError(length, rem int) error {
return lengthError("array", length, rem)
}
// Array is a raw bytes representation of a BSON array.
type Array []byte
// NewArrayFromReader reads an array from r. This function will only validate the length is
// correct and that the array ends with a null byte.
func NewArrayFromReader(r io.Reader) (Array, error) {
return newBufferFromReader(r)
}
// Index searches for and retrieves the value at the given index. This method will panic if
// the array is invalid or if the index is out of bounds.
func (a Array) Index(index uint) Value {
value, err := a.IndexErr(index)
if err != nil {
panic(err)
}
return value
}
// IndexErr searches for and retrieves the value at the given index.
func (a Array) IndexErr(index uint) (Value, error) {
elem, err := indexErr(a, index)
if err != nil {
return Value{}, err
}
return elem.Value(), err
}
// DebugString outputs a human readable version of Array. It will attempt to stringify the
// valid components of the array even if the entire array is not valid.
func (a Array) DebugString() string {
if len(a) < 5 {
return "<malformed>"
}
var buf strings.Builder
buf.WriteString("Array")
length, rem, _ := ReadLength(a) // We know we have enough bytes to read the length
buf.WriteByte('(')
buf.WriteString(strconv.Itoa(int(length)))
length -= 4
buf.WriteString(")[")
var elem Element
var ok bool
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
buf.WriteString(fmt.Sprintf("<malformed (%d)>", length))
break
}
buf.WriteString(elem.Value().DebugString())
if length != 1 {
buf.WriteByte(',')
}
}
buf.WriteByte(']')
return buf.String()
}
// String outputs an ExtendedJSON version of Array. If the Array is not valid, this method
// returns an empty string.
func (a Array) String() string {
str, _ := a.StringN(-1)
return str
}
// StringN stringifies an array. If N is non-negative, it will truncate the string to N bytes.
// Otherwise, it will return the full string representation. The second return value indicates
// whether the string was truncated or not.
func (a Array) StringN(n int) (string, bool) {
length, rem, ok := ReadLength(a)
if !ok || length < 5 {
return "", false
}
length -= 4 // length bytes
length-- // final null byte
if n == 0 {
return "", true
}
var buf strings.Builder
buf.WriteByte('[')
var truncated bool
var elem Element
var str string
first := true
for length > 0 && !truncated {
needStrLen := -1
// Set needStrLen if n is positive, meaning we want to limit the string length.
if n > 0 {
// Stop stringifying if we reach the limit, that also ensures needStrLen is
// greater than 0 if we need to limit the length.
if buf.Len() >= n {
truncated = true
break
}
needStrLen = n - buf.Len()
}
// Append a comma if this is not the first element.
if !first {
buf.WriteByte(',')
// If we are truncating, we need to account for the comma in the length.
if needStrLen > 0 {
needStrLen--
if needStrLen == 0 {
truncated = true
break
}
}
}
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
// Exit on malformed element.
if !ok || length < 0 {
return "", false
}
// Delegate to StringN() on the element.
str, truncated = elem.Value().StringN(needStrLen)
buf.WriteString(str)
first = false
}
if n <= 0 || (buf.Len() < n && !truncated) {
buf.WriteByte(']')
} else {
truncated = true
}
return buf.String(), truncated
}
// Values returns this array as a slice of values. The returned slice will contain valid values.
// If the array is not valid, the values up to the invalid point will be returned along with an
// error.
func (a Array) Values() ([]Value, error) {
return values(a)
}
// Validate validates the array and ensures the elements contained within are valid.
func (a Array) Validate() error {
length, rem, ok := ReadLength(a)
if !ok {
return NewInsufficientBytesError(a, rem)
}
if int(length) > len(a) {
return NewArrayLengthError(int(length), len(a))
}
if a[length-1] != 0x00 {
return ErrMissingNull
}
length -= 4
var elem Element
var keyNum int64
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return NewInsufficientBytesError(a, rem)
}
// validate element
err := elem.Validate()
if err != nil {
return err
}
// validate keys increase numerically
if fmt.Sprint(keyNum) != elem.Key() {
return fmt.Errorf("array key %q is out of order or invalid", elem.Key())
}
keyNum++
}
if len(rem) < 1 || rem[0] != 0x00 {
return ErrMissingNull
}
return nil
}

View File

@@ -0,0 +1,198 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"strconv"
)
// ArrayBuilder builds a bson array
type ArrayBuilder struct {
arr []byte
indexes []int32
keys []int
}
// NewArrayBuilder creates a new ArrayBuilder
func NewArrayBuilder() *ArrayBuilder {
return (&ArrayBuilder{}).startArray()
}
// startArray reserves the array's length and sets the index to where the length begins
func (a *ArrayBuilder) startArray() *ArrayBuilder {
var index int32
index, a.arr = AppendArrayStart(a.arr)
a.indexes = append(a.indexes, index)
a.keys = append(a.keys, 0)
return a
}
// Build updates the length of the array and index to the beginning of the documents length
// bytes, then returns the array (bson bytes)
func (a *ArrayBuilder) Build() Array {
lastIndex := len(a.indexes) - 1
lastKey := len(a.keys) - 1
a.arr, _ = AppendArrayEnd(a.arr, a.indexes[lastIndex])
a.indexes = a.indexes[:lastIndex]
a.keys = a.keys[:lastKey]
return a.arr
}
// incrementKey() increments the value keys and returns the key to be used to a.appendArray* functions
func (a *ArrayBuilder) incrementKey() string {
idx := len(a.keys) - 1
key := strconv.Itoa(a.keys[idx])
a.keys[idx]++
return key
}
// AppendInt32 will append i32 to ArrayBuilder.arr
func (a *ArrayBuilder) AppendInt32(i32 int32) *ArrayBuilder {
a.arr = AppendInt32Element(a.arr, a.incrementKey(), i32)
return a
}
// AppendDocument will append doc to ArrayBuilder.arr
func (a *ArrayBuilder) AppendDocument(doc []byte) *ArrayBuilder {
a.arr = AppendDocumentElement(a.arr, a.incrementKey(), doc)
return a
}
// AppendArray will append arr to ArrayBuilder.arr
func (a *ArrayBuilder) AppendArray(arr []byte) *ArrayBuilder {
a.arr = AppendArrayElement(a.arr, a.incrementKey(), arr)
return a
}
// AppendDouble will append f to ArrayBuilder.doc
func (a *ArrayBuilder) AppendDouble(f float64) *ArrayBuilder {
a.arr = AppendDoubleElement(a.arr, a.incrementKey(), f)
return a
}
// AppendString will append str to ArrayBuilder.doc
func (a *ArrayBuilder) AppendString(str string) *ArrayBuilder {
a.arr = AppendStringElement(a.arr, a.incrementKey(), str)
return a
}
// AppendObjectID will append oid to ArrayBuilder.doc
func (a *ArrayBuilder) AppendObjectID(oid objectID) *ArrayBuilder {
a.arr = AppendObjectIDElement(a.arr, a.incrementKey(), oid)
return a
}
// AppendBinary will append a BSON binary element using subtype, and
// b to a.arr
func (a *ArrayBuilder) AppendBinary(subtype byte, b []byte) *ArrayBuilder {
a.arr = AppendBinaryElement(a.arr, a.incrementKey(), subtype, b)
return a
}
// AppendUndefined will append a BSON undefined element using key to a.arr
func (a *ArrayBuilder) AppendUndefined() *ArrayBuilder {
a.arr = AppendUndefinedElement(a.arr, a.incrementKey())
return a
}
// AppendBoolean will append a boolean element using b to a.arr
func (a *ArrayBuilder) AppendBoolean(b bool) *ArrayBuilder {
a.arr = AppendBooleanElement(a.arr, a.incrementKey(), b)
return a
}
// AppendDateTime will append datetime element dt to a.arr
func (a *ArrayBuilder) AppendDateTime(dt int64) *ArrayBuilder {
a.arr = AppendDateTimeElement(a.arr, a.incrementKey(), dt)
return a
}
// AppendNull will append a null element to a.arr
func (a *ArrayBuilder) AppendNull() *ArrayBuilder {
a.arr = AppendNullElement(a.arr, a.incrementKey())
return a
}
// AppendRegex will append pattern and options to a.arr
func (a *ArrayBuilder) AppendRegex(pattern, options string) *ArrayBuilder {
a.arr = AppendRegexElement(a.arr, a.incrementKey(), pattern, options)
return a
}
// AppendDBPointer will append ns and oid to a.arr
func (a *ArrayBuilder) AppendDBPointer(ns string, oid objectID) *ArrayBuilder {
a.arr = AppendDBPointerElement(a.arr, a.incrementKey(), ns, oid)
return a
}
// AppendJavaScript will append js to a.arr
func (a *ArrayBuilder) AppendJavaScript(js string) *ArrayBuilder {
a.arr = AppendJavaScriptElement(a.arr, a.incrementKey(), js)
return a
}
// AppendSymbol will append symbol to a.arr
func (a *ArrayBuilder) AppendSymbol(symbol string) *ArrayBuilder {
a.arr = AppendSymbolElement(a.arr, a.incrementKey(), symbol)
return a
}
// AppendCodeWithScope will append code and scope to a.arr
func (a *ArrayBuilder) AppendCodeWithScope(code string, scope Document) *ArrayBuilder {
a.arr = AppendCodeWithScopeElement(a.arr, a.incrementKey(), code, scope)
return a
}
// AppendTimestamp will append t and i to a.arr
func (a *ArrayBuilder) AppendTimestamp(t, i uint32) *ArrayBuilder {
a.arr = AppendTimestampElement(a.arr, a.incrementKey(), t, i)
return a
}
// AppendInt64 will append i64 to a.arr
func (a *ArrayBuilder) AppendInt64(i64 int64) *ArrayBuilder {
a.arr = AppendInt64Element(a.arr, a.incrementKey(), i64)
return a
}
// AppendDecimal128 will append d128 to a.arr
func (a *ArrayBuilder) AppendDecimal128(high, low uint64) *ArrayBuilder {
a.arr = AppendDecimal128Element(a.arr, a.incrementKey(), high, low)
return a
}
// AppendMaxKey will append a max key element to a.arr
func (a *ArrayBuilder) AppendMaxKey() *ArrayBuilder {
a.arr = AppendMaxKeyElement(a.arr, a.incrementKey())
return a
}
// AppendMinKey will append a min key element to a.arr
func (a *ArrayBuilder) AppendMinKey() *ArrayBuilder {
a.arr = AppendMinKeyElement(a.arr, a.incrementKey())
return a
}
// AppendValue appends a BSON value to the array.
func (a *ArrayBuilder) AppendValue(val Value) *ArrayBuilder {
a.arr = AppendValueElement(a.arr, a.incrementKey(), val)
return a
}
// StartArray starts building an inline Array. After this document is completed,
// the user must call a.FinishArray
func (a *ArrayBuilder) StartArray() *ArrayBuilder {
a.arr = AppendHeader(a.arr, TypeArray, a.incrementKey())
a.startArray()
return a
}
// FinishArray builds the most recent array created
func (a *ArrayBuilder) FinishArray() *ArrayBuilder {
a.arr = a.Build()
return a
}

View File

@@ -0,0 +1,184 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
// DocumentBuilder builds a bson document
type DocumentBuilder struct {
doc []byte
indexes []int32
}
// startDocument reserves the document's length and set the index to where the length begins
func (db *DocumentBuilder) startDocument() *DocumentBuilder {
var index int32
index, db.doc = AppendDocumentStart(db.doc)
db.indexes = append(db.indexes, index)
return db
}
// NewDocumentBuilder creates a new DocumentBuilder
func NewDocumentBuilder() *DocumentBuilder {
return (&DocumentBuilder{}).startDocument()
}
// Build updates the length of the document and index to the beginning of the documents length
// bytes, then returns the document (bson bytes)
func (db *DocumentBuilder) Build() Document {
last := len(db.indexes) - 1
db.doc, _ = AppendDocumentEnd(db.doc, db.indexes[last])
db.indexes = db.indexes[:last]
return db.doc
}
// AppendInt32 will append an int32 element using key and i32 to DocumentBuilder.doc
func (db *DocumentBuilder) AppendInt32(key string, i32 int32) *DocumentBuilder {
db.doc = AppendInt32Element(db.doc, key, i32)
return db
}
// AppendDocument will append a bson embedded document element using key
// and doc to DocumentBuilder.doc
func (db *DocumentBuilder) AppendDocument(key string, doc []byte) *DocumentBuilder {
db.doc = AppendDocumentElement(db.doc, key, doc)
return db
}
// AppendArray will append a bson array using key and arr to DocumentBuilder.doc
func (db *DocumentBuilder) AppendArray(key string, arr []byte) *DocumentBuilder {
db.doc = AppendHeader(db.doc, TypeArray, key)
db.doc = AppendArray(db.doc, arr)
return db
}
// AppendDouble will append a double element using key and f to DocumentBuilder.doc
func (db *DocumentBuilder) AppendDouble(key string, f float64) *DocumentBuilder {
db.doc = AppendDoubleElement(db.doc, key, f)
return db
}
// AppendString will append str to DocumentBuilder.doc with the given key
func (db *DocumentBuilder) AppendString(key string, str string) *DocumentBuilder {
db.doc = AppendStringElement(db.doc, key, str)
return db
}
// AppendObjectID will append oid to DocumentBuilder.doc with the given key
func (db *DocumentBuilder) AppendObjectID(key string, oid objectID) *DocumentBuilder {
db.doc = AppendObjectIDElement(db.doc, key, oid)
return db
}
// AppendBinary will append a BSON binary element using key, subtype, and
// b to db.doc
func (db *DocumentBuilder) AppendBinary(key string, subtype byte, b []byte) *DocumentBuilder {
db.doc = AppendBinaryElement(db.doc, key, subtype, b)
return db
}
// AppendUndefined will append a BSON undefined element using key to db.doc
func (db *DocumentBuilder) AppendUndefined(key string) *DocumentBuilder {
db.doc = AppendUndefinedElement(db.doc, key)
return db
}
// AppendBoolean will append a boolean element using key and b to db.doc
func (db *DocumentBuilder) AppendBoolean(key string, b bool) *DocumentBuilder {
db.doc = AppendBooleanElement(db.doc, key, b)
return db
}
// AppendDateTime will append a datetime element using key and dt to db.doc
func (db *DocumentBuilder) AppendDateTime(key string, dt int64) *DocumentBuilder {
db.doc = AppendDateTimeElement(db.doc, key, dt)
return db
}
// AppendNull will append a null element using key to db.doc
func (db *DocumentBuilder) AppendNull(key string) *DocumentBuilder {
db.doc = AppendNullElement(db.doc, key)
return db
}
// AppendRegex will append pattern and options using key to db.doc
func (db *DocumentBuilder) AppendRegex(key, pattern, options string) *DocumentBuilder {
db.doc = AppendRegexElement(db.doc, key, pattern, options)
return db
}
// AppendDBPointer will append ns and oid to using key to db.doc
func (db *DocumentBuilder) AppendDBPointer(key string, ns string, oid objectID) *DocumentBuilder {
db.doc = AppendDBPointerElement(db.doc, key, ns, oid)
return db
}
// AppendJavaScript will append js using the provided key to db.doc
func (db *DocumentBuilder) AppendJavaScript(key, js string) *DocumentBuilder {
db.doc = AppendJavaScriptElement(db.doc, key, js)
return db
}
// AppendSymbol will append a BSON symbol element using key and symbol db.doc
func (db *DocumentBuilder) AppendSymbol(key, symbol string) *DocumentBuilder {
db.doc = AppendSymbolElement(db.doc, key, symbol)
return db
}
// AppendCodeWithScope will append code and scope using key to db.doc
func (db *DocumentBuilder) AppendCodeWithScope(key string, code string, scope Document) *DocumentBuilder {
db.doc = AppendCodeWithScopeElement(db.doc, key, code, scope)
return db
}
// AppendTimestamp will append t and i to db.doc using provided key
func (db *DocumentBuilder) AppendTimestamp(key string, t, i uint32) *DocumentBuilder {
db.doc = AppendTimestampElement(db.doc, key, t, i)
return db
}
// AppendInt64 will append i64 to dst using key to db.doc
func (db *DocumentBuilder) AppendInt64(key string, i64 int64) *DocumentBuilder {
db.doc = AppendInt64Element(db.doc, key, i64)
return db
}
// AppendDecimal128 will append d128 to db.doc using provided key
func (db *DocumentBuilder) AppendDecimal128(key string, high, low uint64) *DocumentBuilder {
db.doc = AppendDecimal128Element(db.doc, key, high, low)
return db
}
// AppendMaxKey will append a max key element using key to db.doc
func (db *DocumentBuilder) AppendMaxKey(key string) *DocumentBuilder {
db.doc = AppendMaxKeyElement(db.doc, key)
return db
}
// AppendMinKey will append a min key element using key to db.doc
func (db *DocumentBuilder) AppendMinKey(key string) *DocumentBuilder {
db.doc = AppendMinKeyElement(db.doc, key)
return db
}
// AppendValue will append a BSON element with the provided key and value to the document.
func (db *DocumentBuilder) AppendValue(key string, val Value) *DocumentBuilder {
db.doc = AppendValueElement(db.doc, key, val)
return db
}
// StartDocument starts building an inline document element with the provided key
// After this document is completed, the user must call finishDocument
func (db *DocumentBuilder) StartDocument(key string) *DocumentBuilder {
db.doc = AppendHeader(db.doc, TypeEmbeddedDocument, key)
db = db.startDocument()
return db
}
// FinishDocument builds the most recent document created
func (db *DocumentBuilder) FinishDocument() *DocumentBuilder {
db.doc = db.Build()
return db
}

View File

@@ -0,0 +1,773 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"strconv"
"strings"
"time"
"go.mongodb.org/mongo-driver/v2/internal/binaryutil"
)
const (
// EmptyDocumentLength is the length of a document that has been started/ended but has no elements.
EmptyDocumentLength = 5
// nullTerminator is a string version of the 0 byte that is appended at the end of cstrings.
nullTerminator = string(byte(0))
invalidKeyPanicMsg = "BSON element keys cannot contain null bytes"
invalidRegexPanicMsg = "BSON regex values cannot contain null bytes"
)
type objectID = [12]byte
// AppendType will append t to dst and return the extended buffer.
func AppendType(dst []byte, t Type) []byte { return append(dst, byte(t)) }
// AppendKey will append key to dst and return the extended buffer.
func AppendKey(dst []byte, key string) []byte { return append(dst, key+nullTerminator...) }
// AppendHeader will append Type t and key to dst and return the extended
// buffer.
func AppendHeader(dst []byte, t Type, key string) []byte {
if !isValidCString(key) {
panic(invalidKeyPanicMsg)
}
dst = AppendType(dst, t)
dst = append(dst, key...)
return append(dst, 0x00)
// return append(AppendType(dst, t), key+string(0x00)...)
}
// TODO(skriptble): All of the Read* functions should return src resliced to start just after what was read.
// ReadType will return the first byte of the provided []byte as a type. If
// there is no available byte, false is returned.
func ReadType(src []byte) (Type, []byte, bool) {
if len(src) < 1 {
return 0, src, false
}
return Type(src[0]), src[1:], true
}
// ReadKey will read a key from src. The 0x00 byte will not be present
// in the returned string. If there are not enough bytes available, false is
// returned.
func ReadKey(src []byte) (string, []byte, bool) { return binaryutil.ReadCString(src) }
// ReadKeyBytes will read a key from src as bytes. The 0x00 byte will
// not be present in the returned string. If there are not enough bytes
// available, false is returned.
func ReadKeyBytes(src []byte) ([]byte, []byte, bool) { return binaryutil.ReadCStringBytes(src) }
// ReadHeader will read a type byte and a key from src. If both of these
// values cannot be read, false is returned.
func ReadHeader(src []byte) (t Type, key string, rem []byte, ok bool) {
t, rem, ok = ReadType(src)
if !ok {
return 0, "", src, false
}
key, rem, ok = ReadKey(rem)
if !ok {
return 0, "", src, false
}
return t, key, rem, true
}
// ReadHeaderBytes will read a type and a key from src and the remainder of the bytes
// are returned as rem. If either the type or key cannot be red, ok will be false.
func ReadHeaderBytes(src []byte) (header []byte, rem []byte, ok bool) {
if len(src) < 1 {
return nil, src, false
}
idx := bytes.IndexByte(src[1:], 0x00)
if idx == -1 {
return nil, src, false
}
return src[:idx], src[idx+1:], true
}
// ReadElement reads the next full element from src. It returns the element, the remaining bytes in
// the slice, and a boolean indicating if the read was successful.
func ReadElement(src []byte) (Element, []byte, bool) {
if len(src) < 1 {
return nil, src, false
}
t := Type(src[0])
idx := 1
for idx < len(src) && src[idx] != 0x00 {
idx++
}
if idx >= len(src) {
return nil, src, false
}
idx++ // Move past the null byte
length, ok := valueLength(src[idx:], t)
if !ok {
return nil, src, false
}
elemLength := idx + int(length)
if elemLength > len(src) {
return nil, src, false
}
return src[:elemLength], src[elemLength:], true
}
// AppendValueElement appends value to dst as an element using key as the element's key.
func AppendValueElement(dst []byte, key string, value Value) []byte {
dst = AppendHeader(dst, value.Type, key)
dst = append(dst, value.Data...)
return dst
}
// ReadValue reads the next value as the provided types and returns a Value, the remaining bytes,
// and a boolean indicating if the read was successful.
func ReadValue(src []byte, t Type) (Value, []byte, bool) {
data, rem, ok := readValue(src, t)
if !ok {
return Value{}, src, false
}
return Value{Type: t, Data: data}, rem, true
}
// AppendDouble will append f to dst and return the extended buffer.
func AppendDouble(dst []byte, f float64) []byte {
return binaryutil.Append64(dst, math.Float64bits(f))
}
// AppendDoubleElement will append a BSON double element using key and f to dst
// and return the extended buffer.
func AppendDoubleElement(dst []byte, key string, f float64) []byte {
return AppendDouble(AppendHeader(dst, TypeDouble, key), f)
}
// ReadDouble will read a float64 from src. If there are not enough bytes it
// will return false.
func ReadDouble(src []byte) (float64, []byte, bool) {
bits, src, ok := binaryutil.ReadU64(src)
if !ok {
return 0, src, false
}
return math.Float64frombits(bits), src, true
}
// AppendString will append s to dst and return the extended buffer.
func AppendString(dst []byte, s string) []byte {
return appendstring(dst, s)
}
// AppendStringElement will append a BSON string element using key and val to dst
// and return the extended buffer.
func AppendStringElement(dst []byte, key, val string) []byte {
return AppendString(AppendHeader(dst, TypeString, key), val)
}
// ReadString will read a string from src. If there are not enough bytes it
// will return false.
func ReadString(src []byte) (string, []byte, bool) {
return readstring(src)
}
// AppendDocumentStart reserves a document's length and returns the index where the length begins.
// This index can later be used to write the length of the document.
func AppendDocumentStart(dst []byte) (index int32, b []byte) {
// TODO(skriptble): We really need AppendDocumentStart and AppendDocumentEnd. AppendDocumentStart would handle calling
// TODO ReserveLength and providing the index of the start of the document. AppendDocumentEnd would handle taking that
// TODO start index, adding the null byte, calculating the length, and filling in the length at the start of the
// TODO document.
return ReserveLength(dst)
}
// AppendDocumentStartInline functions the same as AppendDocumentStart but takes a pointer to the
// index int32 which allows this function to be used inline.
func AppendDocumentStartInline(dst []byte, index *int32) []byte {
idx, doc := AppendDocumentStart(dst)
*index = idx
return doc
}
// AppendDocumentElementStart writes a document element header and then reserves the length bytes.
func AppendDocumentElementStart(dst []byte, key string) (index int32, b []byte) {
return AppendDocumentStart(AppendHeader(dst, TypeEmbeddedDocument, key))
}
// AppendDocumentEnd writes the null byte for a document and updates the length of the document.
// The index should be the beginning of the document's length bytes.
func AppendDocumentEnd(dst []byte, index int32) ([]byte, error) {
if int(index) > len(dst)-4 {
return dst, fmt.Errorf("not enough bytes available after index to write length")
}
dst = append(dst, 0x00)
dst = UpdateLength(dst, index, int32(len(dst[index:])))
return dst, nil
}
// AppendDocument will append doc to dst and return the extended buffer.
func AppendDocument(dst []byte, doc []byte) []byte { return append(dst, doc...) }
// AppendDocumentElement will append a BSON embedded document element using key
// and doc to dst and return the extended buffer.
func AppendDocumentElement(dst []byte, key string, doc []byte) []byte {
return AppendDocument(AppendHeader(dst, TypeEmbeddedDocument, key), doc)
}
// BuildDocument will create a document with the given slice of elements and will append
// it to dst and return the extended buffer.
func BuildDocument(dst []byte, elems ...[]byte) []byte {
idx, dst := ReserveLength(dst)
for _, elem := range elems {
dst = append(dst, elem...)
}
dst = append(dst, 0x00)
dst = UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst
}
// BuildDocumentValue creates an Embedded Document value from the given elements.
func BuildDocumentValue(elems ...[]byte) Value {
return Value{Type: TypeEmbeddedDocument, Data: BuildDocument(nil, elems...)}
}
// BuildDocumentElement will append a BSON embedded document element using key and the provided
// elements and return the extended buffer.
func BuildDocumentElement(dst []byte, key string, elems ...[]byte) []byte {
return BuildDocument(AppendHeader(dst, TypeEmbeddedDocument, key), elems...)
}
// BuildDocumentFromElements is an alaias for the BuildDocument function.
var BuildDocumentFromElements = BuildDocument
// ReadDocument will read a document from src. If there are not enough bytes it
// will return false.
func ReadDocument(src []byte) (doc Document, rem []byte, ok bool) { return readLengthBytes(src) }
// AppendArrayStart appends the length bytes to an array and then returns the index of the start
// of those length bytes.
func AppendArrayStart(dst []byte) (index int32, b []byte) { return ReserveLength(dst) }
// AppendArrayElementStart appends an array element header and then the length bytes for an array,
// returning the index where the length starts.
func AppendArrayElementStart(dst []byte, key string) (index int32, b []byte) {
return AppendArrayStart(AppendHeader(dst, TypeArray, key))
}
// AppendArrayEnd appends the null byte to an array and calculates the length, inserting that
// calculated length starting at index.
func AppendArrayEnd(dst []byte, index int32) ([]byte, error) { return AppendDocumentEnd(dst, index) }
// AppendArray will append arr to dst and return the extended buffer.
func AppendArray(dst []byte, arr []byte) []byte { return append(dst, arr...) }
// AppendArrayElement will append a BSON array element using key and arr to dst
// and return the extended buffer.
func AppendArrayElement(dst []byte, key string, arr []byte) []byte {
return AppendArray(AppendHeader(dst, TypeArray, key), arr)
}
// BuildArray will append a BSON array to dst built from values.
func BuildArray(dst []byte, values ...Value) []byte {
idx, dst := ReserveLength(dst)
for pos, val := range values {
dst = AppendValueElement(dst, strconv.Itoa(pos), val)
}
dst = append(dst, 0x00)
dst = UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst
}
// BuildArrayElement will create an array element using the provided values.
func BuildArrayElement(dst []byte, key string, values ...Value) []byte {
return BuildArray(AppendHeader(dst, TypeArray, key), values...)
}
// ReadArray will read an array from src. If there are not enough bytes it
// will return false.
func ReadArray(src []byte) (arr Array, rem []byte, ok bool) { return readLengthBytes(src) }
// AppendBinary will append subtype and b to dst and return the extended buffer.
func AppendBinary(dst []byte, subtype byte, b []byte) []byte {
if subtype == 0x02 {
return appendBinarySubtype2(dst, subtype, b)
}
dst = append(appendLength(dst, int32(len(b))), subtype)
return append(dst, b...)
}
// AppendBinaryElement will append a BSON binary element using key, subtype, and
// b to dst and return the extended buffer.
func AppendBinaryElement(dst []byte, key string, subtype byte, b []byte) []byte {
return AppendBinary(AppendHeader(dst, TypeBinary, key), subtype, b)
}
// ReadBinary will read a subtype and bin from src. If there are not enough bytes it
// will return false.
func ReadBinary(src []byte) (subtype byte, bin []byte, rem []byte, ok bool) {
length, rem, ok := ReadLength(src)
if !ok {
return 0x00, nil, src, false
}
if len(rem) < 1 { // subtype
return 0x00, nil, src, false
}
subtype, rem = rem[0], rem[1:]
if len(rem) < int(length) {
return 0x00, nil, src, false
}
if subtype == 0x02 {
length, rem, ok = ReadLength(rem)
if !ok || len(rem) < int(length) {
return 0x00, nil, src, false
}
}
return subtype, rem[:length], rem[length:], true
}
// AppendUndefinedElement will append a BSON undefined element using key to dst
// and return the extended buffer.
func AppendUndefinedElement(dst []byte, key string) []byte {
return AppendHeader(dst, TypeUndefined, key)
}
// AppendObjectID will append oid to dst and return the extended buffer.
func AppendObjectID(dst []byte, oid objectID) []byte { return append(dst, oid[:]...) }
// AppendObjectIDElement will append a BSON ObjectID element using key and oid to dst
// and return the extended buffer.
func AppendObjectIDElement(dst []byte, key string, oid objectID) []byte {
return AppendObjectID(AppendHeader(dst, TypeObjectID, key), oid)
}
// ReadObjectID will read an ObjectID from src. If there are not enough bytes it
// will return false.
func ReadObjectID(src []byte) ([12]byte, []byte, bool) {
var oid objectID
idLen := cap(oid)
if len(src) < idLen {
return oid, src, false
}
copy(oid[:], src[0:idLen])
return oid, src[idLen:], true
}
// AppendBoolean will append b to dst and return the extended buffer.
func AppendBoolean(dst []byte, b bool) []byte {
if b {
return append(dst, 0x01)
}
return append(dst, 0x00)
}
// AppendBooleanElement will append a BSON boolean element using key and b to dst
// and return the extended buffer.
func AppendBooleanElement(dst []byte, key string, b bool) []byte {
return AppendBoolean(AppendHeader(dst, TypeBoolean, key), b)
}
// ReadBoolean will read a bool from src. If there are not enough bytes it
// will return false.
func ReadBoolean(src []byte) (bool, []byte, bool) {
if len(src) < 1 {
return false, src, false
}
return src[0] == 0x01, src[1:], true
}
// AppendDateTime will append dt to dst and return the extended buffer.
func AppendDateTime(dst []byte, dt int64) []byte { return binaryutil.Append64(dst, dt) }
// AppendDateTimeElement will append a BSON datetime element using key and dt to dst
// and return the extended buffer.
func AppendDateTimeElement(dst []byte, key string, dt int64) []byte {
return AppendDateTime(AppendHeader(dst, TypeDateTime, key), dt)
}
// ReadDateTime will read an int64 datetime from src. If there are not enough bytes it
// will return false.
func ReadDateTime(src []byte) (int64, []byte, bool) { return binaryutil.ReadI64(src) }
// AppendTime will append time as a BSON DateTime to dst and return the extended buffer.
func AppendTime(dst []byte, t time.Time) []byte {
return AppendDateTime(dst, t.Unix()*1000+int64(t.Nanosecond()/1e6))
}
// AppendTimeElement will append a BSON datetime element using key and dt to dst
// and return the extended buffer.
func AppendTimeElement(dst []byte, key string, t time.Time) []byte {
return AppendTime(AppendHeader(dst, TypeDateTime, key), t)
}
// ReadTime will read an time.Time datetime from src. If there are not enough bytes it
// will return false.
func ReadTime(src []byte) (time.Time, []byte, bool) {
dt, rem, ok := binaryutil.ReadI64(src)
return time.Unix(dt/1e3, dt%1e3*1e6), rem, ok
}
// AppendNullElement will append a BSON null element using key to dst
// and return the extended buffer.
func AppendNullElement(dst []byte, key string) []byte { return AppendHeader(dst, TypeNull, key) }
// AppendRegex will append pattern and options to dst and return the extended buffer.
func AppendRegex(dst []byte, pattern, options string) []byte {
if !isValidCString(pattern) || !isValidCString(options) {
panic(invalidRegexPanicMsg)
}
return append(dst, pattern+nullTerminator+options+nullTerminator...)
}
// AppendRegexElement will append a BSON regex element using key, pattern, and
// options to dst and return the extended buffer.
func AppendRegexElement(dst []byte, key, pattern, options string) []byte {
return AppendRegex(AppendHeader(dst, TypeRegex, key), pattern, options)
}
// ReadRegex will read a pattern and options from src. If there are not enough bytes it
// will return false.
func ReadRegex(src []byte) (pattern, options string, rem []byte, ok bool) {
pattern, rem, ok = binaryutil.ReadCString(src)
if !ok {
return "", "", src, false
}
options, rem, ok = binaryutil.ReadCString(rem)
if !ok {
return "", "", src, false
}
return pattern, options, rem, true
}
// AppendDBPointer will append ns and oid to dst and return the extended buffer.
func AppendDBPointer(dst []byte, ns string, oid objectID) []byte {
return append(appendstring(dst, ns), oid[:]...)
}
// AppendDBPointerElement will append a BSON DBPointer element using key, ns,
// and oid to dst and return the extended buffer.
func AppendDBPointerElement(dst []byte, key, ns string, oid objectID) []byte {
return AppendDBPointer(AppendHeader(dst, TypeDBPointer, key), ns, oid)
}
// ReadDBPointer will read a ns and oid from src. If there are not enough bytes it
// will return false.
func ReadDBPointer(src []byte) (ns string, oid [12]byte, rem []byte, ok bool) {
ns, rem, ok = readstring(src)
if !ok {
return "", objectID{}, src, false
}
oid, rem, ok = ReadObjectID(rem)
if !ok {
return "", objectID{}, src, false
}
return ns, oid, rem, true
}
// AppendJavaScript will append js to dst and return the extended buffer.
func AppendJavaScript(dst []byte, js string) []byte { return appendstring(dst, js) }
// AppendJavaScriptElement will append a BSON JavaScript element using key and
// js to dst and return the extended buffer.
func AppendJavaScriptElement(dst []byte, key, js string) []byte {
return AppendJavaScript(AppendHeader(dst, TypeJavaScript, key), js)
}
// ReadJavaScript will read a js string from src. If there are not enough bytes it
// will return false.
func ReadJavaScript(src []byte) (js string, rem []byte, ok bool) { return readstring(src) }
// AppendSymbol will append symbol to dst and return the extended buffer.
func AppendSymbol(dst []byte, symbol string) []byte { return appendstring(dst, symbol) }
// AppendSymbolElement will append a BSON symbol element using key and symbol to dst
// and return the extended buffer.
func AppendSymbolElement(dst []byte, key, symbol string) []byte {
return AppendSymbol(AppendHeader(dst, TypeSymbol, key), symbol)
}
// ReadSymbol will read a symbol string from src. If there are not enough bytes it
// will return false.
func ReadSymbol(src []byte) (symbol string, rem []byte, ok bool) { return readstring(src) }
// AppendCodeWithScope will append code and scope to dst and return the extended buffer.
func AppendCodeWithScope(dst []byte, code string, scope []byte) []byte {
length := int32(4 + 4 + len(code) + 1 + len(scope)) // length of cws, length of code, code, 0x00, scope
dst = appendLength(dst, length)
return append(appendstring(dst, code), scope...)
}
// AppendCodeWithScopeElement will append a BSON code with scope element using
// key, code, and scope to dst
// and return the extended buffer.
func AppendCodeWithScopeElement(dst []byte, key, code string, scope []byte) []byte {
return AppendCodeWithScope(AppendHeader(dst, TypeCodeWithScope, key), code, scope)
}
// ReadCodeWithScope will read code and scope from src. If there are not enough bytes it
// will return false.
func ReadCodeWithScope(src []byte) (code string, scope []byte, rem []byte, ok bool) {
length, rem, ok := ReadLength(src)
if !ok || len(src) < int(length) {
return "", nil, src, false
}
code, rem, ok = readstring(rem)
if !ok {
return "", nil, src, false
}
scope, rem, ok = ReadDocument(rem)
if !ok {
return "", nil, src, false
}
return code, scope, rem, true
}
// AppendInt32 will append i32 to dst and return the extended buffer.
func AppendInt32(dst []byte, i32 int32) []byte { return binaryutil.Append32(dst, i32) }
// AppendInt32Element will append a BSON int32 element using key and i32 to dst
// and return the extended buffer.
func AppendInt32Element(dst []byte, key string, i32 int32) []byte {
return AppendInt32(AppendHeader(dst, TypeInt32, key), i32)
}
// ReadInt32 will read an int32 from src. If there are not enough bytes it
// will return false.
func ReadInt32(src []byte) (int32, []byte, bool) { return binaryutil.ReadI32(src) }
// AppendTimestamp will append t and i to dst and return the extended buffer.
func AppendTimestamp(dst []byte, t, i uint32) []byte {
return binaryutil.Append32(binaryutil.Append32(dst, i), t) // i is the lower 4 bytes, t is the higher 4 bytes
}
// AppendTimestampElement will append a BSON timestamp element using key, t, and
// i to dst and return the extended buffer.
func AppendTimestampElement(dst []byte, key string, t, i uint32) []byte {
return AppendTimestamp(AppendHeader(dst, TypeTimestamp, key), t, i)
}
// ReadTimestamp will read t and i from src. If there are not enough bytes it
// will return false.
func ReadTimestamp(src []byte) (t, i uint32, rem []byte, ok bool) {
i, rem, ok = binaryutil.ReadU32(src)
if !ok {
return 0, 0, src, false
}
t, rem, ok = binaryutil.ReadU32(rem)
if !ok {
return 0, 0, src, false
}
return t, i, rem, true
}
// AppendInt64 will append i64 to dst and return the extended buffer.
func AppendInt64(dst []byte, i64 int64) []byte { return binaryutil.Append64(dst, i64) }
// AppendInt64Element will append a BSON int64 element using key and i64 to dst
// and return the extended buffer.
func AppendInt64Element(dst []byte, key string, i64 int64) []byte {
return AppendInt64(AppendHeader(dst, TypeInt64, key), i64)
}
// ReadInt64 will read an int64 from src. If there are not enough bytes it
// will return false.
func ReadInt64(src []byte) (int64, []byte, bool) { return binaryutil.ReadI64(src) }
// AppendDecimal128 will append high and low parts of a d128 to dst and return the extended buffer.
func AppendDecimal128(dst []byte, high, low uint64) []byte {
return binaryutil.Append64(binaryutil.Append64(dst, low), high)
}
// AppendDecimal128Element will append high and low parts of a BSON bson.Decimal128 element using key and
// d128 to dst and return the extended buffer.
func AppendDecimal128Element(dst []byte, key string, high, low uint64) []byte {
return AppendDecimal128(AppendHeader(dst, TypeDecimal128, key), high, low)
}
// ReadDecimal128 will read high and low parts of a bson.Decimal128 from src. If there are not enough bytes it
// will return false.
func ReadDecimal128(src []byte) (high uint64, low uint64, rem []byte, ok bool) {
low, rem, ok = binaryutil.ReadU64(src)
if !ok {
return 0, 0, src, false
}
high, rem, ok = binaryutil.ReadU64(rem)
if !ok {
return 0, 0, src, false
}
return high, low, rem, true
}
// AppendMaxKeyElement will append a BSON max key element using key to dst
// and return the extended buffer.
func AppendMaxKeyElement(dst []byte, key string) []byte {
return AppendHeader(dst, TypeMaxKey, key)
}
// AppendMinKeyElement will append a BSON min key element using key to dst
// and return the extended buffer.
func AppendMinKeyElement(dst []byte, key string) []byte {
return AppendHeader(dst, TypeMinKey, key)
}
// EqualValue will return true if the two values are equal.
func EqualValue(t1, t2 Type, v1, v2 []byte) bool {
if t1 != t2 {
return false
}
v1, _, ok := readValue(v1, t1)
if !ok {
return false
}
v2, _, ok = readValue(v2, t2)
if !ok {
return false
}
return bytes.Equal(v1, v2)
}
// valueLength will determine the length of the next value contained in src as if it
// is type t. The returned bool will be false if there are not enough bytes in src for
// a value of type t.
func valueLength(src []byte, t Type) (int32, bool) {
var length int32
ok := true
switch t {
case TypeArray, TypeEmbeddedDocument, TypeCodeWithScope:
length, _, ok = ReadLength(src)
case TypeBinary:
length, _, ok = ReadLength(src)
length += 4 + 1 // binary length + subtype byte
case TypeBoolean:
length = 1
case TypeDBPointer:
length, _, ok = ReadLength(src)
length += 4 + 12 // string length + ObjectID length
case TypeDateTime, TypeDouble, TypeInt64, TypeTimestamp:
length = 8
case TypeDecimal128:
length = 16
case TypeInt32:
length = 4
case TypeJavaScript, TypeString, TypeSymbol:
length, _, ok = ReadLength(src)
length += 4
case TypeMaxKey, TypeMinKey, TypeNull, TypeUndefined:
length = 0
case TypeObjectID:
length = 12
case TypeRegex:
regex := bytes.IndexByte(src, 0x00)
if regex < 0 {
ok = false
break
}
pattern := bytes.IndexByte(src[regex+1:], 0x00)
if pattern < 0 {
ok = false
break
}
length = int32(int64(regex) + 1 + int64(pattern) + 1)
default:
ok = false
}
return length, ok
}
func readValue(src []byte, t Type) ([]byte, []byte, bool) {
length, ok := valueLength(src, t)
if !ok || int(length) > len(src) {
return nil, src, false
}
return src[:length], src[length:], true
}
// ReserveLength reserves the space required for length and returns the index where to write the length
// and the []byte with reserved space.
func ReserveLength(dst []byte) (int32, []byte) {
index := len(dst)
return int32(index), append(dst, 0x00, 0x00, 0x00, 0x00)
}
// UpdateLength updates the length at index with length and returns the []byte.
func UpdateLength(dst []byte, index, length int32) []byte {
binary.LittleEndian.PutUint32(dst[index:], uint32(length))
return dst
}
func appendLength(dst []byte, l int32) []byte { return binaryutil.Append32(dst, l) }
// ReadLength reads an int32 length from src and returns the length and the remaining bytes. If
// there aren't enough bytes to read a valid length, src is returned unomdified and the returned
// bool will be false.
func ReadLength(src []byte) (int32, []byte, bool) {
ln, src, ok := binaryutil.ReadI32(src)
if ln < 0 {
return ln, src, false
}
return ln, src, ok
}
func appendstring(dst []byte, s string) []byte {
l := int32(len(s) + 1)
dst = appendLength(dst, l)
dst = append(dst, s...)
return append(dst, 0x00)
}
func readstring(src []byte) (string, []byte, bool) {
l, rem, ok := ReadLength(src)
if !ok {
return "", src, false
}
if len(src[4:]) < int(l) || l == 0 {
return "", src, false
}
return string(rem[:l-1]), rem[l:], true
}
// readLengthBytes attempts to read a length and that number of bytes. This
// function requires that the length include the four bytes for itself.
func readLengthBytes(src []byte) ([]byte, []byte, bool) {
l, _, ok := ReadLength(src)
if !ok {
return nil, src, false
}
if l < 4 {
return nil, src, false
}
if len(src) < int(l) {
return nil, src, false
}
return src[:l], src[l:], true
}
func appendBinarySubtype2(dst []byte, subtype byte, b []byte) []byte {
dst = appendLength(dst, int32(len(b)+4)) // The bytes we'll encode need to be 4 larger for the length bytes
dst = append(dst, subtype)
dst = appendLength(dst, int32(len(b)))
return append(dst, b...)
}
func isValidCString(cs string) bool {
return !strings.ContainsRune(cs, '\x00')
}

View File

@@ -0,0 +1,34 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bsoncore is intended for internal use only. It is made available to
// facilitate use cases that require access to internal MongoDB driver
// functionality and state. The API of this package is not stable and there is
// no backward compatibility guarantee.
//
// WARNING: THIS PACKAGE IS EXPERIMENTAL AND MAY BE MODIFIED OR REMOVED WITHOUT
// NOTICE! USE WITH EXTREME CAUTION!
//
// Package bsoncore contains functions that can be used to encode and decode
// BSON elements and values to or from a slice of bytes. These functions are
// aimed at allowing low level manipulation of BSON and can be used to build a
// higher level BSON library.
//
// The Read* functions within this package return the values of the element and
// a boolean indicating if the values are valid. A boolean was used instead of
// an error because any error that would be returned would be the same: not
// enough bytes. This library attempts to do no validation, it will only return
// false if there are not enough bytes for an item to be read. For example, the
// ReadDocument function checks the length, if that length is larger than the
// number of bytes available, it will return false, if there are enough bytes,
// it will return those bytes and true. It is the consumers responsibility to
// validate those bytes.
//
// The Append* functions within this package will append the type value to the
// given dst slice. If the slice has enough capacity, it will not grow the
// slice. The Append*Element functions within this package operate in the same
// way, but additionally append the BSON type and the key before the value.
package bsoncore

View File

@@ -0,0 +1,431 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"errors"
"fmt"
"io"
"strconv"
"strings"
"go.mongodb.org/mongo-driver/v2/internal/binaryutil"
)
// ValidationError is an error type returned when attempting to validate a document or array.
type ValidationError string
func (ve ValidationError) Error() string { return string(ve) }
// NewDocumentLengthError creates and returns an error for when the length of a document exceeds the
// bytes available.
func NewDocumentLengthError(length, rem int) error {
return lengthError("document", length, rem)
}
func lengthError(bufferType string, length, rem int) error {
return ValidationError(fmt.Sprintf("%v length exceeds available bytes. length=%d remainingBytes=%d",
bufferType, length, rem))
}
// InsufficientBytesError indicates that there were not enough bytes to read the next component.
type InsufficientBytesError struct {
Source []byte
Remaining []byte
}
// NewInsufficientBytesError creates a new InsufficientBytesError with the given Document and
// remaining bytes.
func NewInsufficientBytesError(src, rem []byte) InsufficientBytesError {
return InsufficientBytesError{Source: src, Remaining: rem}
}
// Error implements the error interface.
func (ibe InsufficientBytesError) Error() string {
return "too few bytes to read next component"
}
// Equal checks that err2 also is an ErrTooSmall.
func (ibe InsufficientBytesError) Equal(err2 error) bool {
switch err2.(type) {
case InsufficientBytesError:
return true
default:
return false
}
}
// InvalidDepthTraversalError is returned when attempting a recursive Lookup when one component of
// the path is neither an embedded document nor an array.
type InvalidDepthTraversalError struct {
Key string
Type Type
}
func (idte InvalidDepthTraversalError) Error() string {
return fmt.Sprintf(
"attempt to traverse into %s, but it's type is %s, not %s nor %s",
idte.Key, idte.Type, TypeEmbeddedDocument, TypeArray,
)
}
// ErrMissingNull is returned when a document or array's last byte is not null.
const ErrMissingNull ValidationError = "document or array end is missing null byte"
// ErrInvalidLength indicates that a length in a binary representation of a BSON document or array
// is invalid.
const ErrInvalidLength ValidationError = "document or array length is invalid"
// ErrNilReader indicates that an operation was attempted on a nil io.Reader.
var ErrNilReader = errors.New("nil reader")
// ErrEmptyKey indicates that no key was provided to a Lookup method.
var ErrEmptyKey = errors.New("empty key provided")
// ErrElementNotFound indicates that an Element matching a certain condition does not exist.
var ErrElementNotFound = errors.New("element not found")
// ErrOutOfBounds indicates that an index provided to access something was invalid.
var ErrOutOfBounds = errors.New("out of bounds")
// Document is a raw bytes representation of a BSON document.
type Document []byte
// NewDocumentFromReader reads a document from r. This function will only validate the length is
// correct and that the document ends with a null byte.
func NewDocumentFromReader(r io.Reader) (Document, error) {
return newBufferFromReader(r)
}
func newBufferFromReader(r io.Reader) ([]byte, error) {
if r == nil {
return nil, ErrNilReader
}
var lengthBytes [4]byte
// ReadFull guarantees that we will have read at least len(lengthBytes) if err == nil
_, err := io.ReadFull(r, lengthBytes[:])
if err != nil {
return nil, err
}
length, _, _ := binaryutil.ReadI32(lengthBytes[:]) // ignore ok since we always have enough bytes to read a length
if length < 0 {
return nil, ErrInvalidLength
}
buffer := make([]byte, length)
copy(buffer, lengthBytes[:])
_, err = io.ReadFull(r, buffer[4:])
if err != nil {
return nil, err
}
if buffer[length-1] != 0x00 {
return nil, ErrMissingNull
}
return buffer, nil
}
// Lookup searches the document, potentially recursively, for the given key. If there are multiple
// keys provided, this method will recurse down, as long as the top and intermediate nodes are
// either documents or arrays. If an error occurs or if the value doesn't exist, an empty Value is
// returned.
func (d Document) Lookup(key ...string) Value {
val, _ := d.LookupErr(key...)
return val
}
// LookupErr is the same as Lookup, except it returns an error in addition to an empty Value.
func (d Document) LookupErr(key ...string) (Value, error) {
if len(key) < 1 {
return Value{}, ErrEmptyKey
}
length, rem, ok := ReadLength(d)
if !ok {
return Value{}, NewInsufficientBytesError(d, rem)
}
length -= 4
var elem Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return Value{}, NewInsufficientBytesError(d, rem)
}
// We use `KeyBytes` rather than `Key` to avoid a needless string alloc.
if string(elem.KeyBytes()) != key[0] {
continue
}
if len(key) > 1 {
tt := Type(elem[0])
switch tt {
case TypeEmbeddedDocument:
val, err := elem.Value().Document().LookupErr(key[1:]...)
if err != nil {
return Value{}, err
}
return val, nil
case TypeArray:
// Convert to Document to continue Lookup recursion.
val, err := Document(elem.Value().Array()).LookupErr(key[1:]...)
if err != nil {
return Value{}, err
}
return val, nil
default:
return Value{}, InvalidDepthTraversalError{Key: elem.Key(), Type: tt}
}
}
return elem.ValueErr()
}
return Value{}, ErrElementNotFound
}
// Index searches for and retrieves the element at the given index. This method will panic if
// the document is invalid or if the index is out of bounds.
func (d Document) Index(index uint) Element {
elem, err := d.IndexErr(index)
if err != nil {
panic(err)
}
return elem
}
// IndexErr searches for and retrieves the element at the given index.
func (d Document) IndexErr(index uint) (Element, error) {
return indexErr(d, index)
}
func indexErr(b []byte, index uint) (Element, error) {
length, rem, ok := ReadLength(b)
if !ok {
return nil, NewInsufficientBytesError(b, rem)
}
length -= 4
var current uint
var elem Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return nil, NewInsufficientBytesError(b, rem)
}
if current != index {
current++
continue
}
return elem, nil
}
return nil, ErrOutOfBounds
}
// DebugString outputs a human readable version of Document. It will attempt to stringify the
// valid components of the document even if the entire document is not valid.
func (d Document) DebugString() string {
if len(d) < 5 {
return "<malformed>"
}
var buf strings.Builder
buf.WriteString("Document")
length, rem, _ := ReadLength(d) // We know we have enough bytes to read the length
buf.WriteByte('(')
buf.WriteString(strconv.Itoa(int(length)))
length -= 4
buf.WriteString("){")
var elem Element
var ok bool
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
buf.WriteString(fmt.Sprintf("<malformed (%d)>", length))
break
}
buf.WriteString(elem.DebugString())
}
buf.WriteByte('}')
return buf.String()
}
// String outputs an ExtendedJSON version of Document. If the document is not valid, this method
// returns an empty string.
func (d Document) String() string {
str, _ := d.StringN(-1)
return str
}
// StringN stringifies a document. If N is non-negative, it will truncate the string to N bytes.
// Otherwise, it will return the full string representation. The second return value indicates
// whether the string was truncated or not.
func (d Document) StringN(n int) (string, bool) {
length, rem, ok := ReadLength(d)
if !ok || length < 5 {
return "", false
}
length -= 4 // length bytes
length-- // final null byte
if n == 0 {
return "", true
}
var buf strings.Builder
buf.WriteByte('{')
var truncated bool
var elem Element
var str string
first := true
for length > 0 && !truncated {
needStrLen := -1
// Set needStrLen if n is positive, meaning we want to limit the string length.
if n > 0 {
// Stop stringifying if we reach the limit, that also ensures needStrLen is
// greater than 0 if we need to limit the length.
if buf.Len() >= n {
truncated = true
break
}
needStrLen = n - buf.Len()
}
// Append a comma if this is not the first element.
if !first {
buf.WriteByte(',')
// If we are truncating, we need to account for the comma in the length.
if needStrLen > 0 {
needStrLen--
if needStrLen == 0 {
truncated = true
break
}
}
}
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
// Exit on malformed element.
if !ok || length < 0 {
return "", false
}
// Delegate to StringN() on the element.
str, truncated = elem.StringN(needStrLen)
buf.WriteString(str)
first = false
}
if n <= 0 || (buf.Len() < n && !truncated) {
buf.WriteByte('}')
} else {
truncated = true
}
return buf.String(), truncated
}
// Elements returns this document as a slice of elements. The returned slice will contain valid
// elements. If the document is not valid, the elements up to the invalid point will be returned
// along with an error.
func (d Document) Elements() ([]Element, error) {
length, rem, ok := ReadLength(d)
if !ok {
return nil, NewInsufficientBytesError(d, rem)
}
length -= 4
var elem Element
var elems []Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return elems, NewInsufficientBytesError(d, rem)
}
if err := elem.Validate(); err != nil {
return elems, err
}
elems = append(elems, elem)
}
return elems, nil
}
// Values returns this document as a slice of values. The returned slice will contain valid values.
// If the document is not valid, the values up to the invalid point will be returned along with an
// error.
func (d Document) Values() ([]Value, error) {
return values(d)
}
func values(b []byte) ([]Value, error) {
length, rem, ok := ReadLength(b)
if !ok {
return nil, NewInsufficientBytesError(b, rem)
}
length -= 4
var elem Element
var vals []Value
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return vals, NewInsufficientBytesError(b, rem)
}
if err := elem.Value().Validate(); err != nil {
return vals, err
}
vals = append(vals, elem.Value())
}
return vals, nil
}
// Validate validates the document and ensures the elements contained within are valid.
func (d Document) Validate() error {
length, rem, ok := ReadLength(d)
if !ok {
return NewInsufficientBytesError(d, rem)
}
if int(length) > len(d) {
return NewDocumentLengthError(int(length), len(d))
}
if d[length-1] != 0x00 {
return ErrMissingNull
}
length -= 4
var elem Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return NewInsufficientBytesError(d, rem)
}
err := elem.Validate()
if err != nil {
return err
}
}
if len(rem) < 1 || rem[0] != 0x00 {
return ErrMissingNull
}
return nil
}

View File

@@ -0,0 +1,213 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"fmt"
"strings"
"go.mongodb.org/mongo-driver/v2/internal/bsoncoreutil"
)
// MalformedElementError represents a class of errors that RawElement methods return.
type MalformedElementError string
func (mee MalformedElementError) Error() string { return string(mee) }
// ErrElementMissingKey is returned when a RawElement is missing a key.
const ErrElementMissingKey MalformedElementError = "element is missing key"
// ErrElementMissingType is returned when a RawElement is missing a type.
const ErrElementMissingType MalformedElementError = "element is missing type"
// Element is a raw bytes representation of a BSON element.
type Element []byte
// Key returns the key for this element. If the element is not valid, this method returns an empty
// string. If knowing if the element is valid is important, use KeyErr.
func (e Element) Key() string {
key, _ := e.KeyErr()
return key
}
// KeyBytes returns the key for this element as a []byte. If the element is not valid, this method
// returns an empty string. If knowing if the element is valid is important, use KeyErr. This method
// will not include the null byte at the end of the key in the slice of bytes.
func (e Element) KeyBytes() []byte {
key, _ := e.KeyBytesErr()
return key
}
// KeyErr returns the key for this element, returning an error if the element is not valid.
func (e Element) KeyErr() (string, error) {
key, err := e.KeyBytesErr()
return string(key), err
}
// KeyBytesErr returns the key for this element as a []byte, returning an error if the element is
// not valid.
func (e Element) KeyBytesErr() ([]byte, error) {
if len(e) == 0 {
return nil, ErrElementMissingType
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return nil, ErrElementMissingKey
}
return e[1 : idx+1], nil
}
// Validate ensures the element is a valid BSON element.
func (e Element) Validate() error {
if len(e) < 1 {
return ErrElementMissingType
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return ErrElementMissingKey
}
return Value{Type: Type(e[0]), Data: e[idx+2:]}.Validate()
}
// CompareKey will compare this element's key to key. This method makes it easy to compare keys
// without needing to allocate a string. The key may be null terminated. If a valid key cannot be
// read this method will return false.
func (e Element) CompareKey(key []byte) bool {
if len(e) < 2 {
return false
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return false
}
if index := bytes.IndexByte(key, 0x00); index > -1 {
key = key[:index]
}
return bytes.Equal(e[1:idx+1], key)
}
// Value returns the value of this element. If the element is not valid, this method returns an
// empty Value. If knowing if the element is valid is important, use ValueErr.
func (e Element) Value() Value {
val, _ := e.ValueErr()
return val
}
// ValueErr returns the value for this element, returning an error if the element is not valid.
func (e Element) ValueErr() (Value, error) {
if len(e) == 0 {
return Value{}, ErrElementMissingType
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return Value{}, ErrElementMissingKey
}
val, rem, exists := ReadValue(e[idx+2:], Type(e[0]))
if !exists {
return Value{}, NewInsufficientBytesError(e, rem)
}
return val, nil
}
// String implements the fmt.String interface. The output will be in extended JSON format.
func (e Element) String() string {
str, _ := e.StringN(-1)
return str
}
// StringN will return values in extended JSON format that will stringify an element upto N bytes.
// If N is non-negative, it will truncate the string to N bytes. Otherwise, it will return the full
// string representation. The second return value indicates whether the string was truncated or not.
// If the element is not valid, this returns an empty string
func (e Element) StringN(n int) (string, bool) {
if len(e) == 0 {
return "", false
}
if n == 0 {
return "", true
}
if n == 1 {
return `"`, true
}
t := Type(e[0])
idx := bytes.IndexByte(e[1:], 0x00)
if idx <= 0 {
return "", false
}
key := e[1 : idx+1]
var buf strings.Builder
buf.WriteByte('"')
const suffix = `": `
switch {
case n < 0 || idx <= n-buf.Len()-len(suffix):
buf.Write(key)
buf.WriteString(suffix)
case idx < n:
buf.Write(key)
buf.WriteString(suffix[:n-idx-1])
return buf.String(), true
default:
buf.WriteString(bsoncoreutil.Truncate(string(key), n-1))
return buf.String(), true
}
needStrLen := -1
// Set needStrLen if n is positive, meaning we want to limit the string length.
if n > 0 {
// Stop stringifying if we reach the limit, that also ensures needStrLen is
// greater than 0 if we need to limit the length.
if buf.Len() >= n {
return buf.String(), true
}
needStrLen = n - buf.Len()
}
val, _, valid := ReadValue(e[idx+2:], t)
if !valid {
return "", false
}
var str string
var truncated bool
if _, ok := val.StringValueOK(); ok {
str, truncated = val.StringN(needStrLen)
} else if arr, ok := val.ArrayOK(); ok {
str, truncated = arr.StringN(needStrLen)
} else {
str = val.String()
if needStrLen > 0 && len(str) > needStrLen {
truncated = true
str = bsoncoreutil.Truncate(str, needStrLen)
}
}
buf.WriteString(str)
return buf.String(), truncated
}
// DebugString outputs a human readable version of RawElement. It will attempt to stringify the
// valid components of the element even if the entire element is not valid.
func (e Element) DebugString() string {
if len(e) == 0 {
return "<malformed>"
}
t := Type(e[0])
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return fmt.Sprintf(`bson.Element{[%s]<malformed>}`, t)
}
key, valBytes := []byte(e[1:idx+1]), []byte(e[idx+2:])
val, _, valid := ReadValue(valBytes, t)
if !valid {
return fmt.Sprintf(`bson.Element{[%s]"%s": <malformed>}`, t, key)
}
return fmt.Sprintf(`bson.Element{[%s]"%s": %v}`, t, key, val)
}

View File

@@ -0,0 +1,113 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"errors"
"fmt"
"io"
)
// errCorruptedDocument is returned when a full document couldn't be read from
// the sequence.
var errCorruptedDocument = errors.New("invalid DocumentSequence: corrupted document")
// Iterator maintains a list of BSON values and keeps track of the current
// position in relation to its Next() method.
type Iterator struct {
List Array // List of BSON values
pos int // The position of the iterator in the list in reference to Next()
}
// Count returned the number of elements in the iterator's list.
func (iter *Iterator) Count() int {
if iter == nil {
return 0
}
_, rem, ok := ReadLength(iter.List)
if !ok {
return 0
}
var count int
for len(rem) > 1 {
_, rem, ok = ReadElement(rem)
if !ok {
return 0
}
count++
}
return count
}
// Empty returns true if the iterator's list is empty.
func (iter *Iterator) Empty() bool {
return len(iter.List) <= 5
}
// Reset will reset the iteration point for the Next method to the beginning of
// the list.
func (iter *Iterator) Reset() {
iter.pos = 0
}
// Documents traverses the list as documents and returns them. This method
// assumes that the underlying list is composed of documents and will return
// an error otherwise.
func (iter *Iterator) Documents() ([]Document, error) {
if iter == nil || len(iter.List) == 0 {
return nil, nil
}
vals, err := iter.List.Values()
if err != nil {
return nil, errCorruptedDocument
}
docs := make([]Document, 0, len(vals))
for _, v := range vals {
if v.Type != TypeEmbeddedDocument {
return nil, fmt.Errorf("invalid DocumentSequence: a non-document value was found in sequence")
}
docs = append(docs, v.Data)
}
return docs, nil
}
// Next retrieves the next value from the list and returns it. This method will
// return io.EOF when it has reached the end of the list.
func (iter *Iterator) Next() (*Value, error) {
if iter == nil || iter.pos >= len(iter.List) {
return nil, io.EOF
}
if iter.pos < 4 {
if len(iter.List) < 4 {
return nil, errCorruptedDocument
}
iter.pos = 4 // Skip the length of the document
}
rem := iter.List[iter.pos:]
if len(rem) == 1 && rem[0] == 0x00 {
return nil, io.EOF // At the end of the document
}
elem, _, ok := ReadElement(rem)
if !ok {
return nil, errCorruptedDocument
}
iter.pos += len(elem)
val := elem.Value()
return &val, nil
}

View File

@@ -0,0 +1,223 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/golang/go by The Go Authors
// See THIRD-PARTY-NOTICES for original license terms.
package bsoncore
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': true,
'=': true,
'>': true,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': false,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': false,
'=': true,
'>': false,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}

View File

@@ -0,0 +1,85 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
// Type represents a BSON type.
type Type byte
// String returns the string representation of the BSON type's name.
func (bt Type) String() string {
switch bt {
case '\x01':
return "double"
case '\x02':
return "string"
case '\x03':
return "embedded document"
case '\x04':
return "array"
case '\x05':
return "binary"
case '\x06':
return "undefined"
case '\x07':
return "objectID"
case '\x08':
return "boolean"
case '\x09':
return "UTC datetime"
case '\x0A':
return "null"
case '\x0B':
return "regex"
case '\x0C':
return "dbPointer"
case '\x0D':
return "javascript"
case '\x0E':
return "symbol"
case '\x0F':
return "code with scope"
case '\x10':
return "32-bit integer"
case '\x11':
return "timestamp"
case '\x12':
return "64-bit integer"
case '\x13':
return "128-bit decimal"
case '\x7F':
return "max key"
case '\xFF':
return "min key"
default:
return "invalid"
}
}
// BSON element types as described in https://bsonspec.org/spec.html.
const (
TypeDouble Type = 0x01
TypeString Type = 0x02
TypeEmbeddedDocument Type = 0x03
TypeArray Type = 0x04
TypeBinary Type = 0x05
TypeUndefined Type = 0x06
TypeObjectID Type = 0x07
TypeBoolean Type = 0x08
TypeDateTime Type = 0x09
TypeNull Type = 0x0A
TypeRegex Type = 0x0B
TypeDBPointer Type = 0x0C
TypeJavaScript Type = 0x0D
TypeSymbol Type = 0x0E
TypeCodeWithScope Type = 0x0F
TypeInt32 Type = 0x10
TypeTimestamp Type = 0x11
TypeInt64 Type = 0x12
TypeDecimal128 Type = 0x13
TypeMaxKey Type = 0x7F
TypeMinKey Type = 0xFF
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,248 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"go.mongodb.org/mongo-driver/v2/mongo/address"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/operation"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
const sourceExternal = "$external"
// AuthenticatorFactory constructs an authenticator.
type AuthenticatorFactory func(*Cred, *http.Client) (Authenticator, error)
var authFactories = make(map[string]AuthenticatorFactory)
func init() {
RegisterAuthenticatorFactory("", newDefaultAuthenticator)
RegisterAuthenticatorFactory(SCRAMSHA1, newScramSHA1Authenticator)
RegisterAuthenticatorFactory(SCRAMSHA256, newScramSHA256Authenticator)
RegisterAuthenticatorFactory(PLAIN, newPlainAuthenticator)
RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator)
RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator)
RegisterAuthenticatorFactory(MongoDBAWS, newMongoDBAWSAuthenticator)
RegisterAuthenticatorFactory(MongoDBOIDC, newOIDCAuthenticator)
}
// CreateAuthenticator creates an authenticator.
func CreateAuthenticator(name string, cred *Cred, httpClient *http.Client) (Authenticator, error) {
// Return a custom error to indicate why auth mechanism "MONGODB-CR" is
// missing, even though it was previously available.
if strings.EqualFold(name, "MONGODB-CR") {
return nil, errors.New(`auth mechanism "MONGODB-CR" is no longer available in any supported version of MongoDB`)
}
if f, ok := authFactories[name]; ok {
return f(cred, httpClient)
}
return nil, newAuthError(fmt.Sprintf("unknown authenticator: %s", name), nil)
}
// RegisterAuthenticatorFactory registers the authenticator factory.
func RegisterAuthenticatorFactory(name string, factory AuthenticatorFactory) {
authFactories[name] = factory
}
// HandshakeOptions packages options that can be passed to the Handshaker()
// function. DBUser is optional but must be of the form <dbname.username>;
// if non-empty, then the connection will do SASL mechanism negotiation.
type HandshakeOptions struct {
AppName string
Authenticator Authenticator
Compressors []string
DBUser string
PerformAuthentication func(description.Server) bool
ClusterClock *session.ClusterClock
ServerAPI *driver.ServerAPIOptions
LoadBalanced bool
// Fields provided by a library that wraps the Go Driver.
OuterLibraryName string
OuterLibraryVersion string
OuterLibraryPlatform string
}
type authHandshaker struct {
wrapped driver.Handshaker
options *HandshakeOptions
handshakeInfo driver.HandshakeInformation
conversation SpeculativeConversation
}
var _ driver.Handshaker = (*authHandshaker)(nil)
// GetHandshakeInformation performs the initial MongoDB handshake to retrieve the required information for the provided
// connection.
func (ah *authHandshaker) GetHandshakeInformation(
ctx context.Context,
addr address.Address,
conn *mnet.Connection,
) (driver.HandshakeInformation, error) {
if ah.wrapped != nil {
return ah.wrapped.GetHandshakeInformation(ctx, addr, conn)
}
op := operation.NewHello().
AppName(ah.options.AppName).
Compressors(ah.options.Compressors).
SASLSupportedMechs(ah.options.DBUser).
ClusterClock(ah.options.ClusterClock).
ServerAPI(ah.options.ServerAPI).
LoadBalanced(ah.options.LoadBalanced).
OuterLibraryName(ah.options.OuterLibraryName).
OuterLibraryVersion(ah.options.OuterLibraryVersion).
OuterLibraryPlatform(ah.options.OuterLibraryPlatform)
if ah.options.Authenticator != nil {
if speculativeAuth, ok := ah.options.Authenticator.(SpeculativeAuthenticator); ok {
var err error
ah.conversation, err = speculativeAuth.CreateSpeculativeConversation()
if err != nil {
return driver.HandshakeInformation{}, newAuthError("failed to create conversation", err)
}
// It is possible for the speculative conversation to be nil even without error if the authenticator
// cannot perform speculative authentication. An example of this is MONGODB-OIDC when there is
// no AccessToken in the cache.
if ah.conversation != nil {
firstMsg, err := ah.conversation.FirstMessage()
if err != nil {
return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err)
}
op = op.SpeculativeAuthenticate(firstMsg)
}
}
}
var err error
ah.handshakeInfo, err = op.GetHandshakeInformation(ctx, addr, conn)
if err != nil {
return driver.HandshakeInformation{}, newAuthError("handshake failure", err)
}
return ah.handshakeInfo, nil
}
// FinishHandshake performs authentication for conn if necessary.
func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn *mnet.Connection) error {
performAuth := ah.options.PerformAuthentication
if performAuth == nil {
performAuth = func(serv description.Server) bool {
// Authentication is possible against all server types except arbiters
return serv.Kind != description.ServerKindRSArbiter
}
}
if performAuth(conn.Description()) && ah.options.Authenticator != nil {
cfg := &driver.AuthConfig{
Connection: conn,
ClusterClock: ah.options.ClusterClock,
HandshakeInfo: ah.handshakeInfo,
ServerAPI: ah.options.ServerAPI,
}
if err := ah.authenticate(ctx, cfg); err != nil {
return newAuthError("auth error", err)
}
}
if ah.wrapped == nil {
return nil
}
return ah.wrapped.FinishHandshake(ctx, conn)
}
func (ah *authHandshaker) authenticate(ctx context.Context, cfg *driver.AuthConfig) error {
// If the initial hello reply included a response to the speculative authentication attempt, we only need to
// conduct the remainder of the conversation.
if speculativeResponse := ah.handshakeInfo.SpeculativeAuthenticate; speculativeResponse != nil {
// Defensively ensure that the server did not include a response if speculative auth was not attempted.
if ah.conversation == nil {
return errors.New("speculative auth was not attempted but the server included a response")
}
return ah.conversation.Finish(ctx, cfg, speculativeResponse)
}
// If the server does not support speculative authentication or the first attempt was not successful, we need to
// perform authentication from scratch.
return ah.options.Authenticator.Auth(ctx, cfg)
}
// Handshaker creates a connection handshaker for the given authenticator.
func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshaker {
return &authHandshaker{
wrapped: h,
options: options,
}
}
// Config holds the information necessary to perform an authentication attempt.
type Config struct {
Connection *mnet.Connection
ClusterClock *session.ClusterClock
HandshakeInfo driver.HandshakeInformation
ServerAPI *driver.ServerAPIOptions
HTTPClient *http.Client
}
// Authenticator handles authenticating a connection.
type Authenticator = driver.Authenticator
func newAuthError(msg string, inner error) error {
return &Error{
message: msg,
inner: inner,
}
}
func newError(err error, mech string) error {
return &Error{
message: fmt.Sprintf("unable to authenticate using mechanism \"%s\"", mech),
inner: err,
}
}
// Error is an error that occurred during authentication.
type Error struct {
message string
inner error
}
func (e *Error) Error() string {
if e.inner == nil {
return e.message
}
return fmt.Sprintf("%s: %s", e.message, e.inner)
}
// Inner returns the wrapped error.
func (e *Error) Inner() error {
return e.inner
}
// Unwrap returns the underlying error.
func (e *Error) Unwrap() error {
return e.inner
}
// Message returns the message.
func (e *Error) Message() string {
return e.message
}

View File

@@ -0,0 +1,188 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net/http"
"strings"
"time"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/aws/credentials"
v4signer "go.mongodb.org/mongo-driver/v2/internal/aws/signer/v4"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
type clientState int
const (
clientStarting clientState = iota
clientFirst
clientFinal
clientDone
)
type awsConversation struct {
state clientState
valid bool
nonce []byte
credentials *credentials.Credentials
}
type serverMessage struct {
Nonce bson.Binary `bson:"s"`
Host string `bson:"h"`
}
const (
amzDateFormat = "20060102T150405Z"
defaultRegion = "us-east-1"
maxHostLength = 255
responceNonceLength = 64
)
// Step takes a string provided from a server (or just an empty string for the
// very first conversation step) and attempts to move the authentication
// conversation forward. It returns a string to be sent to the server or an
// error if the server message is invalid. Calling Step after a conversation
// completes is also an error.
func (ac *awsConversation) Step(challenge []byte) (response []byte, err error) {
switch ac.state {
case clientStarting:
ac.state = clientFirst
response = ac.firstMsg()
case clientFirst:
ac.state = clientFinal
response, err = ac.finalMsg(challenge)
case clientFinal:
ac.state = clientDone
ac.valid = true
default:
response, err = nil, errors.New("conversation already completed")
}
return
}
// Done returns true if the conversation is completed or has errored.
func (ac *awsConversation) Done() bool {
return ac.state == clientDone
}
// Valid returns true if the conversation successfully authenticated with the
// server, including counter-validation that the server actually has the
// user's stored credentials.
func (ac *awsConversation) Valid() bool {
return ac.valid
}
func getRegion(host string) (string, error) {
region := defaultRegion
if len(host) == 0 {
return "", errors.New("invalid STS host: empty")
}
if len(host) > maxHostLength {
return "", errors.New("invalid STS host: too large")
}
// The implicit region for sts.amazonaws.com is us-east-1
if host == "sts.amazonaws.com" {
return region, nil
}
if strings.HasPrefix(host, ".") || strings.HasSuffix(host, ".") || strings.Contains(host, "..") {
return "", errors.New("invalid STS host: empty part")
}
// If the host has multiple parts, the second part is the region
parts := strings.Split(host, ".")
if len(parts) >= 2 {
region = parts[1]
}
return region, nil
}
func (ac *awsConversation) firstMsg() []byte {
// Values are cached for use in final message parameters
ac.nonce = make([]byte, 32)
_, _ = rand.Read(ac.nonce)
idx, msg := bsoncore.AppendDocumentStart(nil)
msg = bsoncore.AppendInt32Element(msg, "p", 110)
msg = bsoncore.AppendBinaryElement(msg, "r", 0x00, ac.nonce)
msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
return msg
}
func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
var sm serverMessage
err := bson.Unmarshal(s1, &sm)
if err != nil {
return nil, err
}
// Check nonce prefix
if sm.Nonce.Subtype != 0x00 {
return nil, errors.New("server reply contained unexpected binary subtype")
}
if len(sm.Nonce.Data) != responceNonceLength {
return nil, fmt.Errorf("server reply nonce was not %v bytes", responceNonceLength)
}
if !bytes.HasPrefix(sm.Nonce.Data, ac.nonce) {
return nil, errors.New("server nonce did not extend client nonce")
}
region, err := getRegion(sm.Host)
if err != nil {
return nil, err
}
creds, err := ac.credentials.GetWithContext(context.Background())
if err != nil {
return nil, err
}
currentTime := time.Now().UTC()
body := "Action=GetCallerIdentity&Version=2011-06-15"
// Create http.Request
req, _ := http.NewRequest("POST", "/", strings.NewReader(body))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Content-Length", "43")
req.Host = sm.Host
req.Header.Set("X-Amz-Date", currentTime.Format(amzDateFormat))
if len(creds.SessionToken) > 0 {
req.Header.Set("X-Amz-Security-Token", creds.SessionToken)
}
req.Header.Set("X-MongoDB-Server-Nonce", base64.StdEncoding.EncodeToString(sm.Nonce.Data))
req.Header.Set("X-MongoDB-GS2-CB-Flag", "n")
// Create signer with credentials
signer := v4signer.NewSigner(ac.credentials)
// Get signed header
_, err = signer.Sign(req, strings.NewReader(body), "sts", region, currentTime)
if err != nil {
return nil, err
}
// create message
idx, msg := bsoncore.AppendDocumentStart(nil)
msg = bsoncore.AppendStringElement(msg, "a", req.Header.Get("Authorization"))
msg = bsoncore.AppendStringElement(msg, "d", req.Header.Get("X-Amz-Date"))
if len(creds.SessionToken) > 0 {
msg = bsoncore.AppendStringElement(msg, "t", creds.SessionToken)
}
msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
return msg, nil
}

View File

@@ -0,0 +1,32 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
)
// SpeculativeConversation represents an authentication conversation that can be merged with the initial connection
// handshake.
//
// FirstMessage method returns the first message to be sent to the server. This message will be included in the initial
// hello command.
//
// Finish takes the server response to the initial message and conducts the remainder of the conversation to
// authenticate the provided connection.
type SpeculativeConversation interface {
FirstMessage() (bsoncore.Document, error)
Finish(ctx context.Context, cfg *driver.AuthConfig, firstResponse bsoncore.Document) error
}
// SpeculativeAuthenticator represents an authenticator that supports speculative authentication.
type SpeculativeAuthenticator interface {
CreateSpeculativeConversation() (SpeculativeConversation, error)
}

View File

@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
)
// Cred is the type of user credential
type Cred = driver.Cred

View File

@@ -0,0 +1,58 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package creds
import (
"context"
"net/http"
"time"
"go.mongodb.org/mongo-driver/v2/internal/aws/credentials"
"go.mongodb.org/mongo-driver/v2/internal/credproviders"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
const (
// expiryWindow will allow the credentials to trigger refreshing prior to the credentials actually expiring.
// This is beneficial so expiring credentials do not cause request to fail unexpectedly due to exceptions.
//
// Set an early expiration of 5 minutes before the credentials are actually expired.
expiryWindow = 5 * time.Minute
)
// AWSCredentialProvider wraps AWS credentials.
type AWSCredentialProvider struct {
Cred *credentials.Credentials
}
// NewAWSCredentialProvider generates new AWSCredentialProvider
func NewAWSCredentialProvider(httpClient *http.Client, providers ...credentials.Provider) AWSCredentialProvider {
providers = append(
providers,
credproviders.NewEnvProvider(),
credproviders.NewAssumeRoleProvider(httpClient, expiryWindow),
credproviders.NewECSProvider(httpClient, expiryWindow),
credproviders.NewEC2Provider(httpClient, expiryWindow),
)
return AWSCredentialProvider{credentials.NewChainCredentials(providers)}
}
// GetCredentialsDoc generates AWS credentials.
func (p AWSCredentialProvider) GetCredentialsDoc(ctx context.Context) (bsoncore.Document, error) {
creds, err := p.Cred.GetWithContext(ctx)
if err != nil {
return nil, err
}
builder := bsoncore.NewDocumentBuilder().
AppendString("accessKeyId", creds.AccessKeyID).
AppendString("secretAccessKey", creds.SecretAccessKey)
if token := creds.SessionToken; len(token) > 0 {
builder.AppendString("sessionToken", token)
}
return builder.Build(), nil
}

View File

@@ -0,0 +1,40 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package creds
import (
"context"
"net/http"
"time"
"go.mongodb.org/mongo-driver/v2/internal/aws/credentials"
"go.mongodb.org/mongo-driver/v2/internal/credproviders"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// AzureCredentialProvider provides Azure credentials.
type AzureCredentialProvider struct {
cred *credentials.Credentials
}
// NewAzureCredentialProvider generates new AzureCredentialProvider
func NewAzureCredentialProvider(httpClient *http.Client) AzureCredentialProvider {
return AzureCredentialProvider{
credentials.NewCredentials(credproviders.NewAzureProvider(httpClient, 1*time.Minute)),
}
}
// GetCredentialsDoc generates Azure credentials.
func (p AzureCredentialProvider) GetCredentialsDoc(ctx context.Context) (bsoncore.Document, error) {
creds, err := p.cred.GetWithContext(ctx)
if err != nil {
return nil, err
}
builder := bsoncore.NewDocumentBuilder().
AppendString("accessToken", creds.SessionToken)
return builder.Build(), nil
}

View File

@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package creds is intended for internal use only. It is made available to
// facilitate use cases that require access to internal MongoDB driver
// functionality and state. The API of this package is not stable and there is
// no backward compatibility guarantee.
//
// WARNING: THIS PACKAGE IS EXPERIMENTAL AND MAY BE MODIFIED OR REMOVED WITHOUT
// NOTICE! USE WITH EXTREME CAUTION!
package creds

View File

@@ -0,0 +1,74 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package creds
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// GCPCredentialProvider provides GCP credentials.
type GCPCredentialProvider struct {
httpClient *http.Client
}
// NewGCPCredentialProvider generates new GCPCredentialProvider
func NewGCPCredentialProvider(httpClient *http.Client) GCPCredentialProvider {
return GCPCredentialProvider{httpClient}
}
// GetCredentialsDoc generates GCP credentials.
func (p GCPCredentialProvider) GetCredentialsDoc(ctx context.Context) (bsoncore.Document, error) {
metadataHost := "metadata.google.internal"
if envhost := os.Getenv("GCE_METADATA_HOST"); envhost != "" {
metadataHost = envhost
}
url := fmt.Sprintf("http://%s/computeMetadata/v1/instance/service-accounts/default/token", metadataHost)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("unable to retrieve GCP credentials: %w", err)
}
req.Header.Set("Metadata-Flavor", "Google")
resp, err := p.httpClient.Do(req.WithContext(ctx))
if err != nil {
return nil, fmt.Errorf("unable to retrieve GCP credentials: %w", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("unable to retrieve GCP credentials: error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf(
"unable to retrieve GCP credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s",
resp.StatusCode,
body)
}
var tokenResponse struct {
AccessToken string `json:"access_token"`
}
// Attempt to read body as JSON
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return nil, fmt.Errorf(
"unable to retrieve GCP credentials: error reading body JSON: %w (response body: %s)",
err,
body)
}
if tokenResponse.AccessToken == "" {
return nil, fmt.Errorf("unable to retrieve GCP credentials: got unexpected empty accessToken from GCP Metadata Server. Response body: %s", body)
}
builder := bsoncore.NewDocumentBuilder().AppendString("accessToken", tokenResponse.AccessToken)
return builder.Build(), nil
}

View File

@@ -0,0 +1,80 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"fmt"
"net/http"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
)
func newDefaultAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) {
scram, err := newScramSHA256Authenticator(cred, httpClient)
if err != nil {
return nil, newAuthError("failed to create internal authenticator", err)
}
speculative, ok := scram.(SpeculativeAuthenticator)
if !ok {
typeErr := fmt.Errorf("expected SCRAM authenticator to be SpeculativeAuthenticator but got %T", scram)
return nil, newAuthError("failed to create internal authenticator", typeErr)
}
return &DefaultAuthenticator{
Cred: cred,
speculativeAuthenticator: speculative,
httpClient: httpClient,
}, nil
}
// DefaultAuthenticator uses SCRAM-SHA-1 or SCRAM-SHA-256, depending on the
// server's SASL supported mechanisms.
type DefaultAuthenticator struct {
Cred *Cred
// The authenticator to use for speculative authentication. Because the correct auth mechanism is unknown when doing
// the initial hello, SCRAM-SHA-256 is used for the speculative attempt.
speculativeAuthenticator SpeculativeAuthenticator
httpClient *http.Client
}
var _ SpeculativeAuthenticator = (*DefaultAuthenticator)(nil)
// CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication.
func (a *DefaultAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
return a.speculativeAuthenticator.CreateSpeculativeConversation()
}
// Auth authenticates the connection.
func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error {
actual, err := func() (Authenticator, error) {
// If a server provides a list of supported mechanisms, we choose
// SCRAM-SHA-256 if it exists or else MUST use SCRAM-SHA-1.
// Otherwise, we decide based on what is supported.
if saslSupportedMechs := cfg.HandshakeInfo.SaslSupportedMechs; saslSupportedMechs != nil {
for _, v := range saslSupportedMechs {
if v == SCRAMSHA256 {
return newScramSHA256Authenticator(a.Cred, a.httpClient)
}
}
}
return newScramSHA1Authenticator(a.Cred, a.httpClient)
}()
if err != nil {
return newAuthError("error creating authenticator", err)
}
return actual.Auth(ctx, cfg)
}
// Reauth reauthenticates the connection.
func (a *DefaultAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error {
return newAuthError("DefaultAuthenticator does not support reauthentication", nil)
}

View File

@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package auth is intended for internal use only. It is made available to
// facilitate use cases that require access to internal MongoDB driver
// functionality and state. The API of this package is not stable and there is
// no backward compatibility guarantee.
//
// WARNING: THIS PACKAGE IS EXPERIMENTAL AND MAY BE MODIFIED OR REMOVED WITHOUT
// NOTICE! USE WITH EXTREME CAUTION!
package auth

View File

@@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi && (windows || linux || darwin)
package auth
import (
"context"
"fmt"
"net"
"net/http"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth/internal/gssapi"
)
// GSSAPI is the mechanism name for GSSAPI.
const GSSAPI = "GSSAPI"
func newGSSAPIAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
if cred.Source != "" && cred.Source != sourceExternal {
return nil, newAuthError("GSSAPI source must be empty or $external", nil)
}
return &GSSAPIAuthenticator{
Username: cred.Username,
Password: cred.Password,
PasswordSet: cred.PasswordSet,
Props: cred.Props,
}, nil
}
// GSSAPIAuthenticator uses the GSSAPI algorithm over SASL to authenticate a connection.
type GSSAPIAuthenticator struct {
Username string
Password string
PasswordSet bool
Props map[string]string
}
// Auth authenticates the connection.
func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error {
target := cfg.Connection.Description().Addr.String()
hostname, _, err := net.SplitHostPort(target)
if err != nil {
return newAuthError(fmt.Sprintf("invalid endpoint (%s) specified: %s", target, err), nil)
}
client, err := gssapi.New(hostname, a.Username, a.Password, a.PasswordSet, a.Props)
if err != nil {
return newAuthError("error creating gssapi", err)
}
return ConductSaslConversation(ctx, cfg, sourceExternal, client)
}
// Reauth reauthenticates the connection.
func (a *GSSAPIAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error {
return newAuthError("GSSAPI does not support reauthentication", nil)
}

View File

@@ -0,0 +1,18 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build !gssapi
package auth
import "net/http"
// GSSAPI is the mechanism name for GSSAPI.
const GSSAPI = "GSSAPI"
func newGSSAPIAuthenticator(*Cred, *http.Client) (Authenticator, error) {
return nil, newAuthError("GSSAPI support not enabled during build (-tags gssapi)", nil)
}

View File

@@ -0,0 +1,22 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi && !windows && !linux && !darwin
package auth
import (
"fmt"
"net/http"
"runtime"
)
// GSSAPI is the mechanism name for GSSAPI.
const GSSAPI = "GSSAPI"
func newGSSAPIAuthenticator(*Cred, *http.Client) (Authenticator, error) {
return nil, newAuthError(fmt.Sprintf("GSSAPI is not supported on %s", runtime.GOOS), nil)
}

View File

@@ -0,0 +1,168 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi && (linux || darwin)
// +build gssapi
// +build linux darwin
package gssapi
/*
#cgo linux CFLAGS: -DGOOS_linux
#cgo linux LDFLAGS: -lgssapi_krb5 -lkrb5
#cgo darwin CFLAGS: -DGOOS_darwin
#cgo darwin LDFLAGS: -framework GSS
#include "gss_wrapper.h"
*/
import "C"
import (
"context"
"fmt"
"runtime"
"strings"
"unsafe"
)
// New creates a new SaslClient. The target parameter should be a hostname with no port.
func New(target, username, password string, passwordSet bool, props map[string]string) (*SaslClient, error) {
serviceName := "mongodb"
for key, value := range props {
switch strings.ToUpper(key) {
case "CANONICALIZE_HOST_NAME":
return nil, fmt.Errorf("CANONICALIZE_HOST_NAME is not supported when using gssapi on %s", runtime.GOOS)
case "SERVICE_REALM":
return nil, fmt.Errorf("SERVICE_REALM is not supported when using gssapi on %s", runtime.GOOS)
case "SERVICE_NAME":
serviceName = value
case "SERVICE_HOST":
target = value
default:
return nil, fmt.Errorf("unknown mechanism property %s", key)
}
}
servicePrincipalName := fmt.Sprintf("%s@%s", serviceName, target)
return &SaslClient{
servicePrincipalName: servicePrincipalName,
username: username,
password: password,
passwordSet: passwordSet,
}, nil
}
type SaslClient struct {
servicePrincipalName string
username string
password string
passwordSet bool
// state
state C.gssapi_client_state
contextComplete bool
done bool
}
func (sc *SaslClient) Close() {
C.gssapi_client_destroy(&sc.state)
}
func (sc *SaslClient) Start() (string, []byte, error) {
const mechName = "GSSAPI"
cservicePrincipalName := C.CString(sc.servicePrincipalName)
defer C.free(unsafe.Pointer(cservicePrincipalName))
var cusername *C.char
var cpassword *C.char
if sc.username != "" {
cusername = C.CString(sc.username)
defer C.free(unsafe.Pointer(cusername))
if sc.passwordSet {
cpassword = C.CString(sc.password)
defer C.free(unsafe.Pointer(cpassword))
}
}
status := C.gssapi_client_init(&sc.state, cservicePrincipalName, cusername, cpassword)
if status != C.GSSAPI_OK {
return mechName, nil, sc.getError("unable to initialize client")
}
payload, err := sc.Next(nil, nil)
return mechName, payload, err
}
func (sc *SaslClient) Next(_ context.Context, challenge []byte) ([]byte, error) {
var buf unsafe.Pointer
var bufLen C.size_t
var outBuf unsafe.Pointer
var outBufLen C.size_t
if sc.contextComplete {
if sc.username == "" {
var cusername *C.char
status := C.gssapi_client_username(&sc.state, &cusername)
if status != C.GSSAPI_OK {
return nil, sc.getError("unable to acquire username")
}
defer C.free(unsafe.Pointer(cusername))
sc.username = C.GoString((*C.char)(unsafe.Pointer(cusername)))
}
bytes := append([]byte{1, 0, 0, 0}, []byte(sc.username)...)
buf = unsafe.Pointer(&bytes[0])
bufLen = C.size_t(len(bytes))
status := C.gssapi_client_wrap_msg(&sc.state, buf, bufLen, &outBuf, &outBufLen)
if status != C.GSSAPI_OK {
return nil, sc.getError("unable to wrap authz")
}
sc.done = true
} else {
if len(challenge) > 0 {
buf = unsafe.Pointer(&challenge[0])
bufLen = C.size_t(len(challenge))
}
status := C.gssapi_client_negotiate(&sc.state, buf, bufLen, &outBuf, &outBufLen)
switch status {
case C.GSSAPI_OK:
sc.contextComplete = true
case C.GSSAPI_CONTINUE:
default:
return nil, sc.getError("unable to negotiate with server")
}
}
if outBuf != nil {
defer C.free(outBuf)
}
return C.GoBytes(outBuf, C.int(outBufLen)), nil
}
func (sc *SaslClient) Completed() bool {
return sc.done
}
func (sc *SaslClient) getError(prefix string) error {
var desc *C.char
status := C.gssapi_error_desc(sc.state.maj_stat, sc.state.min_stat, &desc)
if status != C.GSSAPI_OK {
if desc != nil {
C.free(unsafe.Pointer(desc))
}
return fmt.Errorf("%s: (%v, %v)", prefix, sc.state.maj_stat, sc.state.min_stat)
}
defer C.free(unsafe.Pointer(desc))
return fmt.Errorf("%s: %v(%v,%v)", prefix, C.GoString(desc), int32(sc.state.maj_stat), int32(sc.state.min_stat))
}

View File

@@ -0,0 +1,254 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi
//+build linux darwin
#include <string.h>
#include <stdio.h>
#include "gss_wrapper.h"
OM_uint32 gssapi_canonicalize_name(
OM_uint32* minor_status,
char *input_name,
gss_OID input_name_type,
gss_name_t *output_name
)
{
OM_uint32 major_status;
gss_name_t imported_name = GSS_C_NO_NAME;
gss_buffer_desc buffer = GSS_C_EMPTY_BUFFER;
buffer.value = input_name;
buffer.length = strlen(input_name);
major_status = gss_import_name(minor_status, &buffer, input_name_type, &imported_name);
if (GSS_ERROR(major_status)) {
return major_status;
}
major_status = gss_canonicalize_name(minor_status, imported_name, (gss_OID)gss_mech_krb5, output_name);
if (imported_name != GSS_C_NO_NAME) {
OM_uint32 ignored;
gss_release_name(&ignored, &imported_name);
}
return major_status;
}
int gssapi_error_desc(
OM_uint32 maj_stat,
OM_uint32 min_stat,
char **desc
)
{
OM_uint32 stat = maj_stat;
int stat_type = GSS_C_GSS_CODE;
if (min_stat != 0) {
stat = min_stat;
stat_type = GSS_C_MECH_CODE;
}
OM_uint32 local_maj_stat, local_min_stat;
OM_uint32 msg_ctx = 0;
gss_buffer_desc desc_buffer;
do
{
local_maj_stat = gss_display_status(
&local_min_stat,
stat,
stat_type,
GSS_C_NO_OID,
&msg_ctx,
&desc_buffer
);
if (GSS_ERROR(local_maj_stat)) {
return GSSAPI_ERROR;
}
if (*desc) {
free(*desc);
}
*desc = calloc(1, desc_buffer.length + 1);
memcpy(*desc, desc_buffer.value, desc_buffer.length);
gss_release_buffer(&local_min_stat, &desc_buffer);
}
while(msg_ctx != 0);
return GSSAPI_OK;
}
int gssapi_client_init(
gssapi_client_state *client,
char* spn,
char* username,
char* password
)
{
client->cred = GSS_C_NO_CREDENTIAL;
client->ctx = GSS_C_NO_CONTEXT;
client->maj_stat = gssapi_canonicalize_name(&client->min_stat, spn, GSS_C_NT_HOSTBASED_SERVICE, &client->spn);
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
if (username) {
gss_name_t name;
client->maj_stat = gssapi_canonicalize_name(&client->min_stat, username, GSS_C_NT_USER_NAME, &name);
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
if (password) {
gss_buffer_desc password_buffer;
password_buffer.value = password;
password_buffer.length = strlen(password);
client->maj_stat = gss_acquire_cred_with_password(&client->min_stat, name, &password_buffer, GSS_C_INDEFINITE, GSS_C_NO_OID_SET, GSS_C_INITIATE, &client->cred, NULL, NULL);
} else {
client->maj_stat = gss_acquire_cred(&client->min_stat, name, GSS_C_INDEFINITE, GSS_C_NO_OID_SET, GSS_C_INITIATE, &client->cred, NULL, NULL);
}
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
OM_uint32 ignored;
gss_release_name(&ignored, &name);
}
return GSSAPI_OK;
}
int gssapi_client_username(
gssapi_client_state *client,
char** username
)
{
OM_uint32 ignored;
gss_name_t name = GSS_C_NO_NAME;
client->maj_stat = gss_inquire_context(&client->min_stat, client->ctx, &name, NULL, NULL, NULL, NULL, NULL, NULL);
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
gss_buffer_desc name_buffer;
client->maj_stat = gss_display_name(&client->min_stat, name, &name_buffer, NULL);
if (GSS_ERROR(client->maj_stat)) {
gss_release_name(&ignored, &name);
return GSSAPI_ERROR;
}
*username = calloc(1, name_buffer.length + 1);
memcpy(*username, name_buffer.value, name_buffer.length);
gss_release_buffer(&ignored, &name_buffer);
gss_release_name(&ignored, &name);
return GSSAPI_OK;
}
int gssapi_client_negotiate(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
)
{
gss_buffer_desc input_buffer = GSS_C_EMPTY_BUFFER;
gss_buffer_desc output_buffer = GSS_C_EMPTY_BUFFER;
if (input) {
input_buffer.value = input;
input_buffer.length = input_length;
}
client->maj_stat = gss_init_sec_context(
&client->min_stat,
client->cred,
&client->ctx,
client->spn,
GSS_C_NO_OID,
GSS_C_MUTUAL_FLAG | GSS_C_SEQUENCE_FLAG,
0,
GSS_C_NO_CHANNEL_BINDINGS,
&input_buffer,
NULL,
&output_buffer,
NULL,
NULL
);
if (output_buffer.length) {
*output = malloc(output_buffer.length);
*output_length = output_buffer.length;
memcpy(*output, output_buffer.value, output_buffer.length);
OM_uint32 ignored;
gss_release_buffer(&ignored, &output_buffer);
}
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
} else if (client->maj_stat == GSS_S_CONTINUE_NEEDED) {
return GSSAPI_CONTINUE;
}
return GSSAPI_OK;
}
int gssapi_client_wrap_msg(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
)
{
gss_buffer_desc input_buffer = GSS_C_EMPTY_BUFFER;
gss_buffer_desc output_buffer = GSS_C_EMPTY_BUFFER;
input_buffer.value = input;
input_buffer.length = input_length;
client->maj_stat = gss_wrap(&client->min_stat, client->ctx, 0, GSS_C_QOP_DEFAULT, &input_buffer, NULL, &output_buffer);
if (output_buffer.length) {
*output = malloc(output_buffer.length);
*output_length = output_buffer.length;
memcpy(*output, output_buffer.value, output_buffer.length);
gss_release_buffer(&client->min_stat, &output_buffer);
}
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
return GSSAPI_OK;
}
int gssapi_client_destroy(
gssapi_client_state *client
)
{
OM_uint32 ignored;
if (client->ctx != GSS_C_NO_CONTEXT) {
gss_delete_sec_context(&ignored, &client->ctx, GSS_C_NO_BUFFER);
}
if (client->spn != GSS_C_NO_NAME) {
gss_release_name(&ignored, &client->spn);
}
if (client->cred != GSS_C_NO_CREDENTIAL) {
gss_release_cred(&ignored, &client->cred);
}
return GSSAPI_OK;
}

View File

@@ -0,0 +1,72 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi
//+build linux darwin
#ifndef GSS_WRAPPER_H
#define GSS_WRAPPER_H
#include <stdlib.h>
#ifdef GOOS_linux
#include <gssapi/gssapi.h>
#include <gssapi/gssapi_krb5.h>
#endif
#ifdef GOOS_darwin
#include <GSS/GSS.h>
#endif
#define GSSAPI_OK 0
#define GSSAPI_CONTINUE 1
#define GSSAPI_ERROR 2
typedef struct {
gss_name_t spn;
gss_cred_id_t cred;
gss_ctx_id_t ctx;
OM_uint32 maj_stat;
OM_uint32 min_stat;
} gssapi_client_state;
int gssapi_error_desc(
OM_uint32 maj_stat,
OM_uint32 min_stat,
char **desc
);
int gssapi_client_init(
gssapi_client_state *client,
char* spn,
char* username,
char* password
);
int gssapi_client_username(
gssapi_client_state *client,
char** username
);
int gssapi_client_negotiate(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
);
int gssapi_client_wrap_msg(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
);
int gssapi_client_destroy(
gssapi_client_state *client
);
#endif

View File

@@ -0,0 +1,356 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi && windows
// +build gssapi,windows
package gssapi
// #include "sspi_wrapper.h"
import "C"
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"sync"
"unsafe"
)
// New creates a new SaslClient. The target parameter should be a hostname with no port.
func New(target, username, password string, passwordSet bool, props map[string]string) (*SaslClient, error) {
initOnce.Do(initSSPI)
if initError != nil {
return nil, initError
}
var err error
serviceName := "mongodb"
serviceRealm := ""
canonicalizeHostName := false
var serviceHostSet bool
for key, value := range props {
switch strings.ToUpper(key) {
case "CANONICALIZE_HOST_NAME":
canonicalizeHostName, err = strconv.ParseBool(value)
if err != nil {
return nil, fmt.Errorf("%s must be a boolean (true, false, 0, 1) but got '%s'", key, value)
}
case "SERVICE_REALM":
serviceRealm = value
case "SERVICE_NAME":
serviceName = value
case "SERVICE_HOST":
serviceHostSet = true
target = value
}
}
if canonicalizeHostName {
// Should not canonicalize the SERVICE_HOST
if serviceHostSet {
return nil, fmt.Errorf("CANONICALIZE_HOST_NAME and SERVICE_HOST canonot both be specified")
}
names, err := net.LookupAddr(target)
if err != nil || len(names) == 0 {
return nil, fmt.Errorf("unable to canonicalize hostname: %s", err)
}
target = names[0]
if target[len(target)-1] == '.' {
target = target[:len(target)-1]
}
}
servicePrincipalName := fmt.Sprintf("%s/%s", serviceName, target)
if serviceRealm != "" {
servicePrincipalName += "@" + serviceRealm
}
return &SaslClient{
servicePrincipalName: servicePrincipalName,
username: username,
password: password,
passwordSet: passwordSet,
}, nil
}
type SaslClient struct {
servicePrincipalName string
username string
password string
passwordSet bool
// state
state C.sspi_client_state
contextComplete bool
done bool
}
func (sc *SaslClient) Close() {
C.sspi_client_destroy(&sc.state)
}
func (sc *SaslClient) Start() (string, []byte, error) {
const mechName = "GSSAPI"
var cusername *C.char
var cpassword *C.char
if sc.username != "" {
cusername = C.CString(sc.username)
defer C.free(unsafe.Pointer(cusername))
if sc.passwordSet {
cpassword = C.CString(sc.password)
defer C.free(unsafe.Pointer(cpassword))
}
}
status := C.sspi_client_init(&sc.state, cusername, cpassword)
if status != C.SSPI_OK {
return mechName, nil, sc.getError("unable to initialize client")
}
payload, err := sc.Next(nil, nil)
return mechName, payload, err
}
func (sc *SaslClient) Next(_ context.Context, challenge []byte) ([]byte, error) {
var outBuf C.PVOID
var outBufLen C.ULONG
if sc.contextComplete {
if sc.username == "" {
var cusername *C.char
status := C.sspi_client_username(&sc.state, &cusername)
if status != C.SSPI_OK {
return nil, sc.getError("unable to acquire username")
}
defer C.free(unsafe.Pointer(cusername))
sc.username = C.GoString((*C.char)(unsafe.Pointer(cusername)))
}
bytes := append([]byte{1, 0, 0, 0}, []byte(sc.username)...)
buf := (C.PVOID)(unsafe.Pointer(&bytes[0]))
bufLen := C.ULONG(len(bytes))
status := C.sspi_client_wrap_msg(&sc.state, buf, bufLen, &outBuf, &outBufLen)
if status != C.SSPI_OK {
return nil, sc.getError("unable to wrap authz")
}
sc.done = true
} else {
var buf C.PVOID
var bufLen C.ULONG
if len(challenge) > 0 {
buf = (C.PVOID)(unsafe.Pointer(&challenge[0]))
bufLen = C.ULONG(len(challenge))
}
cservicePrincipalName := C.CString(sc.servicePrincipalName)
defer C.free(unsafe.Pointer(cservicePrincipalName))
status := C.sspi_client_negotiate(&sc.state, cservicePrincipalName, buf, bufLen, &outBuf, &outBufLen)
switch status {
case C.SSPI_OK:
sc.contextComplete = true
case C.SSPI_CONTINUE:
default:
return nil, sc.getError("unable to negotiate with server")
}
}
if outBuf != C.PVOID(nil) {
defer C.free(unsafe.Pointer(outBuf))
}
return C.GoBytes(unsafe.Pointer(outBuf), C.int(outBufLen)), nil
}
func (sc *SaslClient) Completed() bool {
return sc.done
}
func (sc *SaslClient) getError(prefix string) error {
return getError(prefix, sc.state.status)
}
var (
initOnce sync.Once
initError error
)
func initSSPI() {
rc := C.sspi_init()
if rc != 0 {
initError = fmt.Errorf("error initializing sspi: %v", rc)
}
}
func getError(prefix string, status C.SECURITY_STATUS) error {
var s string
switch status {
case C.SEC_E_ALGORITHM_MISMATCH:
s = "The client and server cannot communicate because they do not possess a common algorithm."
case C.SEC_E_BAD_BINDINGS:
s = "The SSPI channel bindings supplied by the client are incorrect."
case C.SEC_E_BAD_PKGID:
s = "The requested package identifier does not exist."
case C.SEC_E_BUFFER_TOO_SMALL:
s = "The buffers supplied to the function are not large enough to contain the information."
case C.SEC_E_CANNOT_INSTALL:
s = "The security package cannot initialize successfully and should not be installed."
case C.SEC_E_CANNOT_PACK:
s = "The package is unable to pack the context."
case C.SEC_E_CERT_EXPIRED:
s = "The received certificate has expired."
case C.SEC_E_CERT_UNKNOWN:
s = "An unknown error occurred while processing the certificate."
case C.SEC_E_CERT_WRONG_USAGE:
s = "The certificate is not valid for the requested usage."
case C.SEC_E_CONTEXT_EXPIRED:
s = "The application is referencing a context that has already been closed. A properly written application should not receive this error."
case C.SEC_E_CROSSREALM_DELEGATION_FAILURE:
s = "The server attempted to make a Kerberos-constrained delegation request for a target outside the server's realm."
case C.SEC_E_CRYPTO_SYSTEM_INVALID:
s = "The cryptographic system or checksum function is not valid because a required function is unavailable."
case C.SEC_E_DECRYPT_FAILURE:
s = "The specified data could not be decrypted."
case C.SEC_E_DELEGATION_REQUIRED:
s = "The requested operation cannot be completed. The computer must be trusted for delegation"
case C.SEC_E_DOWNGRADE_DETECTED:
s = "The system detected a possible attempt to compromise security. Verify that the server that authenticated you can be contacted."
case C.SEC_E_ENCRYPT_FAILURE:
s = "The specified data could not be encrypted."
case C.SEC_E_ILLEGAL_MESSAGE:
s = "The message received was unexpected or badly formatted."
case C.SEC_E_INCOMPLETE_CREDENTIALS:
s = "The credentials supplied were not complete and could not be verified. The context could not be initialized."
case C.SEC_E_INCOMPLETE_MESSAGE:
s = "The message supplied was incomplete. The signature was not verified."
case C.SEC_E_INSUFFICIENT_MEMORY:
s = "Not enough memory is available to complete the request."
case C.SEC_E_INTERNAL_ERROR:
s = "An error occurred that did not map to an SSPI error code."
case C.SEC_E_INVALID_HANDLE:
s = "The handle passed to the function is not valid."
case C.SEC_E_INVALID_TOKEN:
s = "The token passed to the function is not valid."
case C.SEC_E_ISSUING_CA_UNTRUSTED:
s = "An untrusted certification authority (CA) was detected while processing the smart card certificate used for authentication."
case C.SEC_E_ISSUING_CA_UNTRUSTED_KDC:
s = "An untrusted CA was detected while processing the domain controller certificate used for authentication. The system event log contains additional information."
case C.SEC_E_KDC_CERT_EXPIRED:
s = "The domain controller certificate used for smart card logon has expired."
case C.SEC_E_KDC_CERT_REVOKED:
s = "The domain controller certificate used for smart card logon has been revoked."
case C.SEC_E_KDC_INVALID_REQUEST:
s = "A request that is not valid was sent to the KDC."
case C.SEC_E_KDC_UNABLE_TO_REFER:
s = "The KDC was unable to generate a referral for the service requested."
case C.SEC_E_KDC_UNKNOWN_ETYPE:
s = "The requested encryption type is not supported by the KDC."
case C.SEC_E_LOGON_DENIED:
s = "The logon has been denied"
case C.SEC_E_MAX_REFERRALS_EXCEEDED:
s = "The number of maximum ticket referrals has been exceeded."
case C.SEC_E_MESSAGE_ALTERED:
s = "The message supplied for verification has been altered."
case C.SEC_E_MULTIPLE_ACCOUNTS:
s = "The received certificate was mapped to multiple accounts."
case C.SEC_E_MUST_BE_KDC:
s = "The local computer must be a Kerberos domain controller (KDC)"
case C.SEC_E_NO_AUTHENTICATING_AUTHORITY:
s = "No authority could be contacted for authentication."
case C.SEC_E_NO_CREDENTIALS:
s = "No credentials are available."
case C.SEC_E_NO_IMPERSONATION:
s = "No impersonation is allowed for this context."
case C.SEC_E_NO_IP_ADDRESSES:
s = "Unable to accomplish the requested task because the local computer does not have any IP addresses."
case C.SEC_E_NO_KERB_KEY:
s = "No Kerberos key was found."
case C.SEC_E_NO_PA_DATA:
s = "Policy administrator (PA) data is needed to determine the encryption type"
case C.SEC_E_NO_S4U_PROT_SUPPORT:
s = "The Kerberos subsystem encountered an error. A service for user protocol request was made against a domain controller which does not support service for a user."
case C.SEC_E_NO_TGT_REPLY:
s = "The client is trying to negotiate a context and the server requires a user-to-user connection"
case C.SEC_E_NOT_OWNER:
s = "The caller of the function does not own the credentials."
case C.SEC_E_OK:
s = "The operation completed successfully."
case C.SEC_E_OUT_OF_SEQUENCE:
s = "The message supplied for verification is out of sequence."
case C.SEC_E_PKINIT_CLIENT_FAILURE:
s = "The smart card certificate used for authentication is not trusted."
case C.SEC_E_PKINIT_NAME_MISMATCH:
s = "The client certificate does not contain a valid UPN or does not match the client name in the logon request."
case C.SEC_E_QOP_NOT_SUPPORTED:
s = "The quality of protection attribute is not supported by this package."
case C.SEC_E_REVOCATION_OFFLINE_C:
s = "The revocation status of the smart card certificate used for authentication could not be determined."
case C.SEC_E_REVOCATION_OFFLINE_KDC:
s = "The revocation status of the domain controller certificate used for smart card authentication could not be determined. The system event log contains additional information."
case C.SEC_E_SECPKG_NOT_FOUND:
s = "The security package was not recognized."
case C.SEC_E_SECURITY_QOS_FAILED:
s = "The security context could not be established due to a failure in the requested quality of service (for example"
case C.SEC_E_SHUTDOWN_IN_PROGRESS:
s = "A system shutdown is in progress."
case C.SEC_E_SMARTCARD_CERT_EXPIRED:
s = "The smart card certificate used for authentication has expired."
case C.SEC_E_SMARTCARD_CERT_REVOKED:
s = "The smart card certificate used for authentication has been revoked. Additional information may exist in the event log."
case C.SEC_E_SMARTCARD_LOGON_REQUIRED:
s = "Smart card logon is required and was not used."
case C.SEC_E_STRONG_CRYPTO_NOT_SUPPORTED:
s = "The other end of the security negotiation requires strong cryptography"
case C.SEC_E_TARGET_UNKNOWN:
s = "The target was not recognized."
case C.SEC_E_TIME_SKEW:
s = "The clocks on the client and server computers do not match."
case C.SEC_E_TOO_MANY_PRINCIPALS:
s = "The KDC reply contained more than one principal name."
case C.SEC_E_UNFINISHED_CONTEXT_DELETED:
s = "A security context was deleted before the context was completed. This is considered a logon failure."
case C.SEC_E_UNKNOWN_CREDENTIALS:
s = "The credentials provided were not recognized."
case C.SEC_E_UNSUPPORTED_FUNCTION:
s = "The requested function is not supported."
case C.SEC_E_UNSUPPORTED_PREAUTH:
s = "An unsupported preauthentication mechanism was presented to the Kerberos package."
case C.SEC_E_UNTRUSTED_ROOT:
s = "The certificate chain was issued by an authority that is not trusted."
case C.SEC_E_WRONG_CREDENTIAL_HANDLE:
s = "The supplied credential handle does not match the credential associated with the security context."
case C.SEC_E_WRONG_PRINCIPAL:
s = "The target principal name is incorrect."
case C.SEC_I_COMPLETE_AND_CONTINUE:
s = "The function completed successfully"
case C.SEC_I_COMPLETE_NEEDED:
s = "The function completed successfully"
case C.SEC_I_CONTEXT_EXPIRED:
s = "The message sender has finished using the connection and has initiated a shutdown. For information about initiating or recognizing a shutdown"
case C.SEC_I_CONTINUE_NEEDED:
s = "The function completed successfully"
case C.SEC_I_INCOMPLETE_CREDENTIALS:
s = "The credentials supplied were not complete and could not be verified. Additional information can be returned from the context."
case C.SEC_I_LOCAL_LOGON:
s = "The logon was completed"
case C.SEC_I_NO_LSA_CONTEXT:
s = "There is no LSA mode context associated with this context."
case C.SEC_I_RENEGOTIATE:
s = "The context data must be renegotiated with the peer."
default:
return fmt.Errorf("%s: 0x%x", prefix, uint32(status))
}
return fmt.Errorf("%s: %s(0x%x)", prefix, s, uint32(status))
}

View File

@@ -0,0 +1,249 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi,windows
#include "sspi_wrapper.h"
static HINSTANCE sspi_secur32_dll = NULL;
static PSecurityFunctionTable sspi_functions = NULL;
static const LPSTR SSPI_PACKAGE_NAME = "kerberos";
int sspi_init(
)
{
// Load the secur32.dll library using its exact path. Passing the exact DLL path rather than allowing LoadLibrary to
// search in different locations removes the possibility of DLL preloading attacks. We use GetSystemDirectoryA and
// LoadLibraryA rather than the GetSystemDirectory/LoadLibrary aliases to ensure the ANSI versions are used so we
// don't have to account for variations in char sizes if UNICODE is enabled.
// Passing a 0 size will return the required buffer length to hold the path, including the null terminator.
int requiredLen = GetSystemDirectoryA(NULL, 0);
if (!requiredLen) {
return GetLastError();
}
// Allocate a buffer to hold the system directory + "\secur32.dll" (length 12, not including null terminator).
int actualLen = requiredLen + 12;
char *directoryBuffer = (char *) calloc(1, actualLen);
int directoryLen = GetSystemDirectoryA(directoryBuffer, actualLen);
if (!directoryLen) {
free(directoryBuffer);
return GetLastError();
}
// Append the DLL name to the buffer.
char *dllName = "\\secur32.dll";
strcpy_s(&(directoryBuffer[directoryLen]), actualLen - directoryLen, dllName);
sspi_secur32_dll = LoadLibraryA(directoryBuffer);
free(directoryBuffer);
if (!sspi_secur32_dll) {
return GetLastError();
}
INIT_SECURITY_INTERFACE init_security_interface = (INIT_SECURITY_INTERFACE)GetProcAddress(sspi_secur32_dll, SECURITY_ENTRYPOINT);
if (!init_security_interface) {
return -1;
}
sspi_functions = (*init_security_interface)();
if (!sspi_functions) {
return -2;
}
return SSPI_OK;
}
int sspi_client_init(
sspi_client_state *client,
char* username,
char* password
)
{
TimeStamp timestamp;
if (username) {
if (password) {
SEC_WINNT_AUTH_IDENTITY auth_identity;
#ifdef _UNICODE
auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
#else
auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_ANSI;
#endif
auth_identity.User = (LPSTR) username;
auth_identity.UserLength = strlen(username);
auth_identity.Password = (LPSTR) password;
auth_identity.PasswordLength = strlen(password);
auth_identity.Domain = NULL;
auth_identity.DomainLength = 0;
client->status = sspi_functions->AcquireCredentialsHandle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, &auth_identity, NULL, NULL, &client->cred, &timestamp);
} else {
client->status = sspi_functions->AcquireCredentialsHandle(username, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &client->cred, &timestamp);
}
} else {
client->status = sspi_functions->AcquireCredentialsHandle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &client->cred, &timestamp);
}
if (client->status != SEC_E_OK) {
return SSPI_ERROR;
}
return SSPI_OK;
}
int sspi_client_username(
sspi_client_state *client,
char** username
)
{
SecPkgCredentials_Names names;
client->status = sspi_functions->QueryCredentialsAttributes(&client->cred, SECPKG_CRED_ATTR_NAMES, &names);
if (client->status != SEC_E_OK) {
return SSPI_ERROR;
}
int len = strlen(names.sUserName) + 1;
*username = malloc(len);
memcpy(*username, names.sUserName, len);
sspi_functions->FreeContextBuffer(names.sUserName);
return SSPI_OK;
}
int sspi_client_negotiate(
sspi_client_state *client,
char* spn,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
)
{
SecBufferDesc inbuf;
SecBuffer in_bufs[1];
SecBufferDesc outbuf;
SecBuffer out_bufs[1];
if (client->has_ctx > 0) {
inbuf.ulVersion = SECBUFFER_VERSION;
inbuf.cBuffers = 1;
inbuf.pBuffers = in_bufs;
in_bufs[0].pvBuffer = input;
in_bufs[0].cbBuffer = input_length;
in_bufs[0].BufferType = SECBUFFER_TOKEN;
}
outbuf.ulVersion = SECBUFFER_VERSION;
outbuf.cBuffers = 1;
outbuf.pBuffers = out_bufs;
out_bufs[0].pvBuffer = NULL;
out_bufs[0].cbBuffer = 0;
out_bufs[0].BufferType = SECBUFFER_TOKEN;
ULONG context_attr = 0;
client->status = sspi_functions->InitializeSecurityContext(
&client->cred,
client->has_ctx > 0 ? &client->ctx : NULL,
(LPSTR) spn,
ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_MUTUAL_AUTH,
0,
SECURITY_NETWORK_DREP,
client->has_ctx > 0 ? &inbuf : NULL,
0,
&client->ctx,
&outbuf,
&context_attr,
NULL);
if (client->status != SEC_E_OK && client->status != SEC_I_CONTINUE_NEEDED) {
return SSPI_ERROR;
}
client->has_ctx = 1;
*output = malloc(out_bufs[0].cbBuffer);
*output_length = out_bufs[0].cbBuffer;
memcpy(*output, out_bufs[0].pvBuffer, *output_length);
sspi_functions->FreeContextBuffer(out_bufs[0].pvBuffer);
if (client->status == SEC_I_CONTINUE_NEEDED) {
return SSPI_CONTINUE;
}
return SSPI_OK;
}
int sspi_client_wrap_msg(
sspi_client_state *client,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
)
{
SecPkgContext_Sizes sizes;
client->status = sspi_functions->QueryContextAttributes(&client->ctx, SECPKG_ATTR_SIZES, &sizes);
if (client->status != SEC_E_OK) {
return SSPI_ERROR;
}
char *msg = malloc((sizes.cbSecurityTrailer + input_length + sizes.cbBlockSize) * sizeof(char));
memcpy(&msg[sizes.cbSecurityTrailer], input, input_length);
SecBuffer wrap_bufs[3];
SecBufferDesc wrap_buf_desc;
wrap_buf_desc.cBuffers = 3;
wrap_buf_desc.pBuffers = wrap_bufs;
wrap_buf_desc.ulVersion = SECBUFFER_VERSION;
wrap_bufs[0].cbBuffer = sizes.cbSecurityTrailer;
wrap_bufs[0].BufferType = SECBUFFER_TOKEN;
wrap_bufs[0].pvBuffer = msg;
wrap_bufs[1].cbBuffer = input_length;
wrap_bufs[1].BufferType = SECBUFFER_DATA;
wrap_bufs[1].pvBuffer = msg + sizes.cbSecurityTrailer;
wrap_bufs[2].cbBuffer = sizes.cbBlockSize;
wrap_bufs[2].BufferType = SECBUFFER_PADDING;
wrap_bufs[2].pvBuffer = msg + sizes.cbSecurityTrailer + input_length;
client->status = sspi_functions->EncryptMessage(&client->ctx, SECQOP_WRAP_NO_ENCRYPT, &wrap_buf_desc, 0);
if (client->status != SEC_E_OK) {
free(msg);
return SSPI_ERROR;
}
*output_length = wrap_bufs[0].cbBuffer + wrap_bufs[1].cbBuffer + wrap_bufs[2].cbBuffer;
*output = malloc(*output_length);
memcpy(*output, wrap_bufs[0].pvBuffer, wrap_bufs[0].cbBuffer);
memcpy(*output + wrap_bufs[0].cbBuffer, wrap_bufs[1].pvBuffer, wrap_bufs[1].cbBuffer);
memcpy(*output + wrap_bufs[0].cbBuffer + wrap_bufs[1].cbBuffer, wrap_bufs[2].pvBuffer, wrap_bufs[2].cbBuffer);
free(msg);
return SSPI_OK;
}
int sspi_client_destroy(
sspi_client_state *client
)
{
if (client->has_ctx > 0) {
sspi_functions->DeleteSecurityContext(&client->ctx);
}
sspi_functions->FreeCredentialsHandle(&client->cred);
return SSPI_OK;
}

View File

@@ -0,0 +1,64 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi,windows
#ifndef SSPI_WRAPPER_H
#define SSPI_WRAPPER_H
#define SECURITY_WIN32 1 /* Required for SSPI */
#include <windows.h>
#include <sspi.h>
#define SSPI_OK 0
#define SSPI_CONTINUE 1
#define SSPI_ERROR 2
typedef struct {
CredHandle cred;
CtxtHandle ctx;
int has_ctx;
SECURITY_STATUS status;
} sspi_client_state;
int sspi_init();
int sspi_client_init(
sspi_client_state *client,
char* username,
char* password
);
int sspi_client_username(
sspi_client_state *client,
char** username
);
int sspi_client_negotiate(
sspi_client_state *client,
char* spn,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
);
int sspi_client_wrap_msg(
sspi_client_state *client,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
);
int sspi_client_destroy(
sspi_client_state *client
);
#endif

View File

@@ -0,0 +1,92 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"errors"
"net/http"
"go.mongodb.org/mongo-driver/v2/internal/aws/credentials"
"go.mongodb.org/mongo-driver/v2/internal/credproviders"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth/creds"
)
// MongoDBAWS is the mechanism name for MongoDBAWS.
const MongoDBAWS = "MONGODB-AWS"
func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) {
if cred.Source != "" && cred.Source != sourceExternal {
return nil, newAuthError("MONGODB-AWS source must be empty or $external", nil)
}
if httpClient == nil {
return nil, errors.New("httpClient must not be nil")
}
return &MongoDBAWSAuthenticator{
credentials: &credproviders.StaticProvider{
Value: credentials.Value{
AccessKeyID: cred.Username,
SecretAccessKey: cred.Password,
SessionToken: cred.Props["AWS_SESSION_TOKEN"],
},
},
httpClient: httpClient,
}, nil
}
// MongoDBAWSAuthenticator uses AWS-IAM credentials over SASL to authenticate a connection.
type MongoDBAWSAuthenticator struct {
credentials *credproviders.StaticProvider
httpClient *http.Client
}
// Auth authenticates the connection.
func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error {
providers := creds.NewAWSCredentialProvider(a.httpClient, a.credentials)
adapter := &awsSaslAdapter{
conversation: &awsConversation{
credentials: providers.Cred,
},
}
err := ConductSaslConversation(ctx, cfg, sourceExternal, adapter)
if err != nil {
return newAuthError("sasl conversation error", err)
}
return nil
}
// Reauth reauthenticates the connection.
func (a *MongoDBAWSAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error {
return newAuthError("AWS authentication does not support reauthentication", nil)
}
type awsSaslAdapter struct {
conversation *awsConversation
}
var _ SaslClient = (*awsSaslAdapter)(nil)
func (a *awsSaslAdapter) Start() (string, []byte, error) {
step, err := a.conversation.Step(nil)
if err != nil {
return MongoDBAWS, nil, err
}
return MongoDBAWS, step, nil
}
func (a *awsSaslAdapter) Next(_ context.Context, challenge []byte) ([]byte, error) {
step, err := a.conversation.Step(challenge)
if err != nil {
return nil, err
}
return step, nil
}
func (a *awsSaslAdapter) Completed() bool {
return a.conversation.Done()
}

View File

@@ -0,0 +1,595 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"regexp"
"strings"
"sync"
"time"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet"
)
// MongoDBOIDC is the string constant for the MONGODB-OIDC authentication mechanism.
const MongoDBOIDC = "MONGODB-OIDC"
// Valid authMechanismProperties keys for MONGODB-OIDC.
const (
// EnvironmentProp is the property key name that specifies the environment for the OIDC authenticator.
EnvironmentProp = "ENVIRONMENT"
// ResourceProp is the property key name that specifies the token resource for GCP and AZURE OIDC auth.
ResourceProp = "TOKEN_RESOURCE"
// AllowedHostsProp is the property key name that specifies the allowed hosts for the OIDC authenticator.
AllowedHostsProp = "ALLOWED_HOSTS"
)
// Valid ENVIRONMENT authMechismProperty values for MONGODB-OIDC.
const (
// AzureEnvironmentValue is the value for the Azure environment.
AzureEnvironmentValue = "azure"
// GCPEnvironmentValue is the value for the GCP environment.
GCPEnvironmentValue = "gcp"
// K8SEnvironmentValue is the value for Kubernetes environments.
K8SEnvironmentValue = "k8s"
// TestEnvironmentValue is the value for the test environment.
TestEnvironmentValue = "test"
)
const (
apiVersion = 1
invalidateSleepTimeout = 100 * time.Millisecond
// The CSOT specification says to apply a 1-minute timeout if "CSOT is not applied". That's
// ambiguous for the v1.x Go Driver because it could mean either "no timeout provided" or "CSOT not
// enabled". Always use a maximum timeout duration of 1 minute, allowing us to ignore the ambiguity.
// Contexts with a shorter timeout are unaffected.
machineCallbackTimeout = time.Minute
humanCallbackTimeout = 5 * time.Minute
)
var defaultAllowedHosts = []*regexp.Regexp{
regexp.MustCompile(`^.*[.]mongodb[.]net(:\d+)?$`),
regexp.MustCompile(`^.*[.]mongodb-qa[.]net(:\d+)?$`),
regexp.MustCompile(`^.*[.]mongodb-dev[.]net(:\d+)?$`),
regexp.MustCompile(`^.*[.]mongodbgov[.]net(:\d+)?$`),
regexp.MustCompile(`^localhost(:\d+)?$`),
regexp.MustCompile(`^127[.]0[.]0[.]1(:\d+)?$`),
regexp.MustCompile(`^::1(:\d+)?$`),
regexp.MustCompile(`^.*[.]mongo[.]com(:\d+)?$`),
}
// OIDCCallback is a function that takes a context and OIDCArgs and returns an OIDCCredential.
type OIDCCallback = driver.OIDCCallback
// OIDCArgs contains the arguments for the OIDC callback.
type OIDCArgs = driver.OIDCArgs
// OIDCCredential contains the access token and refresh token.
type OIDCCredential = driver.OIDCCredential
// IDPInfo contains the information needed to perform OIDC authentication with an Identity Provider.
type IDPInfo = driver.IDPInfo
var (
_ driver.Authenticator = (*OIDCAuthenticator)(nil)
_ SpeculativeAuthenticator = (*OIDCAuthenticator)(nil)
_ SaslClient = (*oidcOneStep)(nil)
_ SaslClient = (*oidcTwoStep)(nil)
)
// OIDCAuthenticator is synchronized and handles caching of the access token, refreshToken,
// and IDPInfo. It also provides a mechanism to refresh the access token, but this functionality
// is only for the OIDC Human flow.
type OIDCAuthenticator struct {
mu sync.Mutex // Guards all of the info in the OIDCAuthenticator struct.
AuthMechanismProperties map[string]string
OIDCMachineCallback OIDCCallback
OIDCHumanCallback OIDCCallback
allowedHosts *[]*regexp.Regexp
userName string
httpClient *http.Client
accessToken string
refreshToken *string
idpInfo *IDPInfo
tokenGenID uint64
}
// SetAccessToken allows for manually setting the access token for the OIDCAuthenticator, this is
// only for testing purposes.
func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) {
oa.mu.Lock()
defer oa.mu.Unlock()
oa.accessToken = accessToken
}
func newOIDCAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) {
if cred.Source != "" && cred.Source != sourceExternal {
return nil, newAuthError("MONGODB-OIDC source must be empty or $external", nil)
}
if cred.Password != "" {
return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC)
}
if cred.Props != nil {
if env, ok := cred.Props[EnvironmentProp]; ok {
switch strings.ToLower(env) {
case AzureEnvironmentValue, GCPEnvironmentValue:
if _, ok := cred.Props[ResourceProp]; !ok {
return nil, fmt.Errorf("%q must be specified for %q %q", ResourceProp, env, EnvironmentProp)
}
fallthrough
case K8SEnvironmentValue, TestEnvironmentValue:
if cred.OIDCMachineCallback != nil || cred.OIDCHumanCallback != nil {
return nil, fmt.Errorf("OIDC callbacks are not allowed for %q %q", env, EnvironmentProp)
}
}
}
}
oa := &OIDCAuthenticator{
userName: cred.Username,
httpClient: httpClient,
AuthMechanismProperties: cred.Props,
OIDCMachineCallback: cred.OIDCMachineCallback,
OIDCHumanCallback: cred.OIDCHumanCallback,
}
err := oa.setAllowedHosts()
return oa, err
}
func createPatternsForGlobs(hosts []string) ([]*regexp.Regexp, error) {
var err error
ret := make([]*regexp.Regexp, len(hosts))
for i := range hosts {
hosts[i] = strings.ReplaceAll(hosts[i], ".", "[.]")
hosts[i] = strings.ReplaceAll(hosts[i], "*", ".*")
hosts[i] = "^" + hosts[i] + "(:\\d+)?$"
ret[i], err = regexp.Compile(hosts[i])
if err != nil {
return nil, err
}
}
return ret, nil
}
func (oa *OIDCAuthenticator) setAllowedHosts() error {
if oa.AuthMechanismProperties == nil {
oa.allowedHosts = &defaultAllowedHosts
return nil
}
allowedHosts, ok := oa.AuthMechanismProperties[AllowedHostsProp]
if !ok {
oa.allowedHosts = &defaultAllowedHosts
return nil
}
globs := strings.Split(allowedHosts, ",")
ret, err := createPatternsForGlobs(globs)
if err != nil {
return err
}
oa.allowedHosts = &ret
return nil
}
func (oa *OIDCAuthenticator) validateConnectionAddressWithAllowedHosts(conn *mnet.Connection) error {
if oa.allowedHosts == nil {
// should be unreachable, but this is a safety check.
return newAuthError(fmt.Sprintf("%q missing", AllowedHostsProp), nil)
}
allowedHosts := *oa.allowedHosts
if len(allowedHosts) == 0 {
return newAuthError(fmt.Sprintf("empty %q specified", AllowedHostsProp), nil)
}
for _, pattern := range allowedHosts {
if pattern.MatchString(string(conn.Address())) {
return nil
}
}
return newAuthError(fmt.Sprintf("address %q not allowed by %q: %v", conn.Address(), AllowedHostsProp, allowedHosts), nil)
}
type oidcOneStep struct {
userName string
accessToken string
}
type oidcTwoStep struct {
conn *mnet.Connection
oa *OIDCAuthenticator
}
func jwtStepRequest(accessToken string) []byte {
return bsoncore.NewDocumentBuilder().
AppendString("jwt", accessToken).
Build()
}
func principalStepRequest(principal string) []byte {
doc := bsoncore.NewDocumentBuilder()
if principal != "" {
doc.AppendString("n", principal)
}
return doc.Build()
}
func (oos *oidcOneStep) Start() (string, []byte, error) {
return MongoDBOIDC, jwtStepRequest(oos.accessToken), nil
}
func (oos *oidcOneStep) Next(context.Context, []byte) ([]byte, error) {
return nil, newAuthError("unexpected step in OIDC authentication", nil)
}
func (*oidcOneStep) Completed() bool {
return true
}
func (ots *oidcTwoStep) Start() (string, []byte, error) {
return MongoDBOIDC, principalStepRequest(ots.oa.userName), nil
}
func (ots *oidcTwoStep) Next(ctx context.Context, msg []byte) ([]byte, error) {
var idpInfo IDPInfo
err := bson.Unmarshal(msg, &idpInfo)
if err != nil {
return nil, fmt.Errorf("error unmarshaling BSON document: %w", err)
}
accessToken, err := ots.oa.getAccessToken(ctx,
ots.conn,
&OIDCArgs{
Version: apiVersion,
// idpInfo is nil for machine callbacks in the current spec.
IDPInfo: &idpInfo,
// there is no way there could be a refresh token when there is no IDPInfo.
RefreshToken: nil,
},
// two-step callbacks are always human callbacks.
ots.oa.OIDCHumanCallback)
return jwtStepRequest(accessToken), err
}
func (*oidcTwoStep) Completed() bool {
return true
}
func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) {
env, ok := oa.AuthMechanismProperties[EnvironmentProp]
if !ok {
return nil, nil
}
switch env {
case AzureEnvironmentValue:
resource, ok := oa.AuthMechanismProperties[ResourceProp]
if !ok {
return nil, newAuthError(fmt.Sprintf("%q must be specified for Azure OIDC", ResourceProp), nil)
}
return getAzureOIDCCallback(oa.userName, resource, oa.httpClient), nil
case GCPEnvironmentValue:
resource, ok := oa.AuthMechanismProperties[ResourceProp]
if !ok {
return nil, newAuthError(fmt.Sprintf("%q must be specified for GCP OIDC", ResourceProp), nil)
}
return getGCPOIDCCallback(resource, oa.httpClient), nil
case K8SEnvironmentValue:
return k8sOIDCCallback, nil
}
return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", EnvironmentProp, env)
}
// getAzureOIDCCallback returns the callback for the Azure Identity Provider.
func getAzureOIDCCallback(clientID string, resource string, httpClient *http.Client) OIDCCallback {
// return the callback parameterized by the clientID and resource, also passing in the user
// configured httpClient.
return func(ctx context.Context, _ *OIDCArgs) (*OIDCCredential, error) {
resource = url.QueryEscape(resource)
var uri string
if clientID != "" {
uri = fmt.Sprintf("http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=%s&client_id=%s", resource, clientID)
} else {
uri = fmt.Sprintf("http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=%s", resource)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return nil, newAuthError("error creating http request to Azure Identity Provider", err)
}
req.Header.Add("Metadata", "true")
req.Header.Add("Accept", "application/json")
resp, err := httpClient.Do(req)
if err != nil {
return nil, newAuthError("error getting access token from Azure Identity Provider", err)
}
defer resp.Body.Close()
var azureResp struct {
AccessToken string `json:"access_token"`
ExpiresOn int64 `json:"expires_on,string"`
}
if resp.StatusCode != http.StatusOK {
return nil, newAuthError(fmt.Sprintf("failed to get a valid response from Azure Identity Provider, http code: %d", resp.StatusCode), nil)
}
err = json.NewDecoder(resp.Body).Decode(&azureResp)
if err != nil {
return nil, newAuthError("failed parsing result from Azure Identity Provider", err)
}
expireTime := time.Unix(azureResp.ExpiresOn, 0)
return &OIDCCredential{
AccessToken: azureResp.AccessToken,
ExpiresAt: &expireTime,
}, nil
}
}
// getGCPOIDCCallback returns the callback for the GCP Identity Provider.
func getGCPOIDCCallback(resource string, httpClient *http.Client) OIDCCallback {
// return the callback parameterized by the clientID and resource, also passing in the user
// configured httpClient.
return func(ctx context.Context, _ *OIDCArgs) (*OIDCCredential, error) {
resource = url.QueryEscape(resource)
uri := fmt.Sprintf("http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience=%s", resource)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return nil, newAuthError("error creating http request to GCP Identity Provider", err)
}
req.Header.Add("Metadata-Flavor", "Google")
resp, err := httpClient.Do(req)
if err != nil {
return nil, newAuthError("error getting access token from GCP Identity Provider", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, newAuthError(fmt.Sprintf("failed to get a valid response from GCP Identity Provider, http code: %d", resp.StatusCode), nil)
}
accessToken, err := io.ReadAll(resp.Body)
if err != nil {
return nil, newAuthError("failed parsing reading response from GCP Identity Provider", err)
}
return &OIDCCredential{
AccessToken: string(accessToken),
ExpiresAt: nil,
}, nil
}
}
// k8sOIDCCallbackfunc is the callback for the Kubernetes token provider.
func k8sOIDCCallback(context.Context, *OIDCArgs) (*OIDCCredential, error) {
// Check for the presence of the Azure and AWS token file path environment
// variables. If neither are set, use the GKE default token file path.
var path string
if p := os.Getenv("AZURE_FEDERATED_TOKEN_FILE"); p != "" {
path = p
} else if p := os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE"); p != "" {
path = p
} else {
path = "/var/run/secrets/kubernetes.io/serviceaccount/token"
}
token, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("error reading OIDC token from %q: %w", path, err)
}
return &OIDCCredential{
AccessToken: string(token),
}, nil
}
func (oa *OIDCAuthenticator) getAccessToken(
ctx context.Context,
conn *mnet.Connection,
args *OIDCArgs,
callback OIDCCallback,
) (string, error) {
oa.mu.Lock()
defer oa.mu.Unlock()
if oa.accessToken != "" {
return oa.accessToken, nil
}
// Attempt to refresh the access token if a refresh token is available.
if args.RefreshToken != nil {
cred, err := callback(ctx, args)
if err == nil && cred != nil {
oa.accessToken = cred.AccessToken
oa.tokenGenID++
conn.SetOIDCTokenGenID(oa.tokenGenID)
oa.refreshToken = cred.RefreshToken
return cred.AccessToken, nil
}
oa.refreshToken = nil
args.RefreshToken = nil
}
// If we get here this means there either was no refresh token or the refresh token failed.
cred, err := callback(ctx, args)
if err != nil {
return "", err
}
// This line should never occur, if go conventions are followed, but it is a safety check such
// that we do not throw nil pointer errors to our users if they abuse the API.
if cred == nil {
return "", newAuthError("OIDC callback returned nil credential with no specified error", nil)
}
oa.accessToken = cred.AccessToken
oa.tokenGenID++
conn.SetOIDCTokenGenID(oa.tokenGenID)
oa.refreshToken = cred.RefreshToken
// always set the IdPInfo, in most cases, this should just be recopying the same pointer, or nil
// in the machine flow.
oa.idpInfo = args.IDPInfo
return cred.AccessToken, nil
}
// invalidateAccessToken invalidates the access token, if the force flag is set to true (which is
// only on a Reauth call) or if the tokenGenID of the connection is greater than or equal to the
// tokenGenID of the OIDCAuthenticator. It should never actually be greater than, but only equal,
// but this is a safety check, since extra invalidation is only a performance impact, not a
// correctness impact.
func (oa *OIDCAuthenticator) invalidateAccessToken(conn *mnet.Connection) {
oa.mu.Lock()
defer oa.mu.Unlock()
tokenGenID := conn.OIDCTokenGenID()
// If the connection used in a Reauth is a new connection it will not have a correct tokenGenID,
// it will instead be set to 0. In the absence of information, the only safe thing to do is to
// invalidate the cached accessToken.
if tokenGenID == 0 || tokenGenID >= oa.tokenGenID {
oa.accessToken = ""
conn.SetOIDCTokenGenID(0)
}
}
// Reauth reauthenticates the connection when the server returns a 391 code. Reauth is part of the
// driver.Authenticator interface.
func (oa *OIDCAuthenticator) Reauth(ctx context.Context, cfg *driver.AuthConfig) error {
oa.invalidateAccessToken(cfg.Connection)
return oa.Auth(ctx, cfg)
}
// Auth authenticates the connection.
func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error {
var err error
if cfg == nil {
return newAuthError(fmt.Sprintf("config must be set for %q authentication", MongoDBOIDC), nil)
}
conn := cfg.Connection
oa.mu.Lock()
cachedAccessToken := oa.accessToken
cachedRefreshToken := oa.refreshToken
cachedIDPInfo := oa.idpInfo
oa.mu.Unlock()
if cachedAccessToken != "" {
err = ConductSaslConversation(ctx, cfg, sourceExternal, &oidcOneStep{
userName: oa.userName,
accessToken: cachedAccessToken,
})
if err == nil {
return nil
}
// this seems like it could be incorrect since we could be inavlidating an access token that
// has already been replaced by a different auth attempt, but the TokenGenID will prevernt
// that from happening.
oa.invalidateAccessToken(conn)
time.Sleep(invalidateSleepTimeout)
}
if oa.OIDCHumanCallback != nil {
return oa.doAuthHuman(ctx, cfg, oa.OIDCHumanCallback, cachedIDPInfo, cachedRefreshToken)
}
// Handle user provided or automatic provider machine callback.
var machineCallback OIDCCallback
if oa.OIDCMachineCallback != nil {
machineCallback = oa.OIDCMachineCallback
} else {
machineCallback, err = oa.providerCallback()
if err != nil {
return fmt.Errorf("error getting built-in OIDC provider: %w", err)
}
}
if machineCallback != nil {
return oa.doAuthMachine(ctx, cfg, machineCallback)
}
return newAuthError("no OIDC callback provided", nil)
}
func (oa *OIDCAuthenticator) doAuthHuman(ctx context.Context, cfg *driver.AuthConfig, humanCallback OIDCCallback, idpInfo *IDPInfo, refreshToken *string) error {
// Ensure that the connection address is allowed by the allowed hosts.
err := oa.validateConnectionAddressWithAllowedHosts(cfg.Connection)
if err != nil {
return err
}
subCtx, cancel := context.WithTimeout(ctx, humanCallbackTimeout)
defer cancel()
// If the idpInfo exists, we can just do one step
if idpInfo != nil {
accessToken, err := oa.getAccessToken(subCtx,
cfg.Connection,
&OIDCArgs{
Version: apiVersion,
// idpInfo is nil for machine callbacks in the current spec.
IDPInfo: idpInfo,
RefreshToken: refreshToken,
},
humanCallback)
if err != nil {
return err
}
return ConductSaslConversation(
subCtx,
cfg,
sourceExternal,
&oidcOneStep{accessToken: accessToken},
)
}
// otherwise, we need the two step where we ask the server for the IdPInfo first.
ots := &oidcTwoStep{
conn: cfg.Connection,
oa: oa,
}
return ConductSaslConversation(subCtx, cfg, sourceExternal, ots)
}
func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *driver.AuthConfig, machineCallback OIDCCallback) error {
subCtx, cancel := context.WithTimeout(ctx, machineCallbackTimeout)
accessToken, err := oa.getAccessToken(subCtx,
cfg.Connection,
&OIDCArgs{
Version: apiVersion,
// idpInfo is nil for machine callbacks in the current spec.
IDPInfo: nil,
RefreshToken: nil,
},
machineCallback)
cancel()
if err != nil {
return err
}
return ConductSaslConversation(
ctx,
cfg,
sourceExternal,
&oidcOneStep{accessToken: accessToken},
)
}
// CreateSpeculativeConversation creates a speculative conversation for OIDC authentication.
func (oa *OIDCAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
oa.mu.Lock()
defer oa.mu.Unlock()
accessToken := oa.accessToken
if accessToken == "" {
return nil, nil // Skip speculative auth.
}
return newSaslConversation(&oidcOneStep{accessToken: accessToken}, sourceExternal, true), nil
}

View File

@@ -0,0 +1,78 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"net/http"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
)
// PLAIN is the mechanism name for PLAIN.
const PLAIN = "PLAIN"
func newPlainAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
// TODO(GODRIVER-3317): The PLAIN specification says about auth source:
//
// "MUST be specified. Defaults to the database name if supplied on the
// connection string or $external."
//
// We should actually pass through the auth source, not always pass
// $external. If it's empty, we should default to $external.
//
// For example:
//
// source := cred.Source
// if source == "" {
// source = "$external"
// }
//
return &PlainAuthenticator{
Username: cred.Username,
Password: cred.Password,
}, nil
}
// PlainAuthenticator uses the PLAIN algorithm over SASL to authenticate a connection.
type PlainAuthenticator struct {
Username string
Password string
}
// Auth authenticates the connection.
func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error {
return ConductSaslConversation(ctx, cfg, sourceExternal, &plainSaslClient{
username: a.Username,
password: a.Password,
})
}
// Reauth reauthenticates the connection.
func (a *PlainAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error {
return newAuthError("Plain authentication does not support reauthentication", nil)
}
type plainSaslClient struct {
username string
password string
}
var _ SaslClient = (*plainSaslClient)(nil)
func (c *plainSaslClient) Start() (string, []byte, error) {
b := []byte("\x00" + c.username + "\x00" + c.password)
return PLAIN, b, nil
}
func (c *plainSaslClient) Next(context.Context, []byte) ([]byte, error) {
return nil, newAuthError("unexpected server challenge", nil)
}
func (c *plainSaslClient) Completed() bool {
return true
}

View File

@@ -0,0 +1,173 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/operation"
)
// SaslClient is the client piece of a sasl conversation.
type SaslClient interface {
Start() (string, []byte, error)
Next(ctx context.Context, challenge []byte) ([]byte, error)
Completed() bool
}
// SaslClientCloser is a SaslClient that has resources to clean up.
type SaslClientCloser interface {
SaslClient
Close()
}
// ExtraOptionsSaslClient is a SaslClient that appends options to the saslStart command.
type ExtraOptionsSaslClient interface {
StartCommandOptions() bsoncore.Document
}
// saslConversation represents a SASL conversation. This type implements the SpeculativeConversation interface so the
// conversation can be executed in multi-step speculative fashion.
type saslConversation struct {
client SaslClient
source string
mechanism string
speculative bool
}
var _ SpeculativeConversation = (*saslConversation)(nil)
func newSaslConversation(client SaslClient, source string, speculative bool) *saslConversation {
authSource := source
if authSource == "" {
authSource = defaultAuthDB
}
return &saslConversation{
client: client,
source: authSource,
speculative: speculative,
}
}
// FirstMessage returns the first message to be sent to the server. This message contains a "db" field so it can be used
// for speculative authentication.
func (sc *saslConversation) FirstMessage() (bsoncore.Document, error) {
var payload []byte
var err error
sc.mechanism, payload, err = sc.client.Start()
if err != nil {
return nil, err
}
saslCmdElements := [][]byte{
bsoncore.AppendInt32Element(nil, "saslStart", 1),
bsoncore.AppendStringElement(nil, "mechanism", sc.mechanism),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
}
if sc.speculative {
// The "db" field is only appended for speculative auth because the hello command is executed against admin
// so this is needed to tell the server the user's auth source. For a non-speculative attempt, the SASL commands
// will be executed against the auth source.
saslCmdElements = append(saslCmdElements, bsoncore.AppendStringElement(nil, "db", sc.source))
}
if extraOptionsClient, ok := sc.client.(ExtraOptionsSaslClient); ok {
optionsDoc := extraOptionsClient.StartCommandOptions()
saslCmdElements = append(saslCmdElements, bsoncore.AppendDocumentElement(nil, "options", optionsDoc))
}
return bsoncore.BuildDocumentFromElements(nil, saslCmdElements...), nil
}
type saslResponse struct {
ConversationID int `bson:"conversationId"`
Code int `bson:"code"`
Done bool `bson:"done"`
Payload []byte `bson:"payload"`
}
// Finish completes the conversation based on the first server response to authenticate the given connection.
func (sc *saslConversation) Finish(ctx context.Context, cfg *driver.AuthConfig, firstResponse bsoncore.Document) error {
if closer, ok := sc.client.(SaslClientCloser); ok {
defer closer.Close()
}
var saslResp saslResponse
err := bson.Unmarshal(firstResponse, &saslResp)
if err != nil {
fullErr := fmt.Errorf("unmarshal error: %w", err)
return newError(fullErr, sc.mechanism)
}
cid := saslResp.ConversationID
var payload []byte
var rdr bsoncore.Document
for {
if saslResp.Code != 0 {
return newError(err, sc.mechanism)
}
if saslResp.Done && sc.client.Completed() {
return nil
}
payload, err = sc.client.Next(ctx, saslResp.Payload)
if err != nil {
return newError(err, sc.mechanism)
}
if saslResp.Done && sc.client.Completed() {
return nil
}
doc := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "saslContinue", 1),
bsoncore.AppendInt32Element(nil, "conversationId", int32(cid)),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
)
saslContinueCmd := operation.NewCommand(doc).
Database(sc.source).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
err = saslContinueCmd.Execute(ctx)
if err != nil {
return newError(err, sc.mechanism)
}
rdr = saslContinueCmd.Result()
err = bson.Unmarshal(rdr, &saslResp)
if err != nil {
fullErr := fmt.Errorf("unmarshal error: %w", err)
return newError(fullErr, sc.mechanism)
}
}
}
// ConductSaslConversation runs a full SASL conversation to authenticate the given connection.
func ConductSaslConversation(ctx context.Context, cfg *driver.AuthConfig, authSource string, client SaslClient) error {
// Create a non-speculative SASL conversation.
conversation := newSaslConversation(client, authSource, false)
saslStartDoc, err := conversation.FirstMessage()
if err != nil {
return newError(err, conversation.mechanism)
}
saslStartCmd := operation.NewCommand(saslStartDoc).
Database(authSource).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
if err := saslStartCmd.Execute(ctx); err != nil {
return newError(err, conversation.mechanism)
}
return conversation.Finish(ctx, cfg, saslStartCmd.Result())
}

View File

@@ -0,0 +1,144 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Copyright (C) MongoDB, Inc. 2018-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"net/http"
"github.com/xdg-go/scram"
"github.com/xdg-go/stringprep"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
)
const (
// SCRAMSHA1 holds the mechanism name "SCRAM-SHA-1"
SCRAMSHA1 = "SCRAM-SHA-1"
// SCRAMSHA256 holds the mechanism name "SCRAM-SHA-256"
SCRAMSHA256 = "SCRAM-SHA-256"
)
// Additional options for the saslStart command to enable a shorter SCRAM conversation
var scramStartOptions bsoncore.Document = bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendBooleanElement(nil, "skipEmptyExchange", true),
)
func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
source := cred.Source
if source == "" {
source = "admin"
}
passdigest := mongoPasswordDigest(cred.Username, cred.Password)
client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "")
if err != nil {
return nil, newAuthError("error initializing SCRAM-SHA-1 client", err)
}
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA1,
source: source,
client: client,
}, nil
}
func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
source := cred.Source
if source == "" {
source = "admin"
}
passprep, err := stringprep.SASLprep.Prepare(cred.Password)
if err != nil {
return nil, newAuthError("error SASLprepping password", err)
}
client, err := scram.SHA256.NewClientUnprepped(cred.Username, passprep, "")
if err != nil {
return nil, newAuthError("error initializing SCRAM-SHA-256 client", err)
}
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA256,
source: source,
client: client,
}, nil
}
// ScramAuthenticator uses the SCRAM algorithm over SASL to authenticate a connection.
type ScramAuthenticator struct {
mechanism string
source string
client *scram.Client
}
var _ SpeculativeAuthenticator = (*ScramAuthenticator)(nil)
// Auth authenticates the provided connection by conducting a full SASL conversation.
func (a *ScramAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error {
err := ConductSaslConversation(ctx, cfg, a.source, a.createSaslClient())
if err != nil {
return newAuthError("sasl conversation error", err)
}
return nil
}
// Reauth reauthenticates the connection.
func (a *ScramAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error {
return newAuthError("SCRAM does not support reauthentication", nil)
}
// CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication.
func (a *ScramAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
return newSaslConversation(a.createSaslClient(), a.source, true), nil
}
func (a *ScramAuthenticator) createSaslClient() SaslClient {
return &scramSaslAdapter{
conversation: a.client.NewConversation(),
mechanism: a.mechanism,
}
}
type scramSaslAdapter struct {
mechanism string
conversation *scram.ClientConversation
}
var (
_ SaslClient = (*scramSaslAdapter)(nil)
_ ExtraOptionsSaslClient = (*scramSaslAdapter)(nil)
)
func (a *scramSaslAdapter) Start() (string, []byte, error) {
step, err := a.conversation.Step("")
if err != nil {
return a.mechanism, nil, err
}
return a.mechanism, []byte(step), nil
}
func (a *scramSaslAdapter) Next(_ context.Context, challenge []byte) ([]byte, error) {
step, err := a.conversation.Step(string(challenge))
if err != nil {
return nil, err
}
return []byte(step), nil
}
func (a *scramSaslAdapter) Completed() bool {
return a.conversation.Done()
}
func (*scramSaslAdapter) StartCommandOptions() bsoncore.Document {
return scramStartOptions
}

View File

@@ -0,0 +1,30 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"fmt"
"io"
// Ignore gosec warning "Blocklisted import crypto/md5: weak cryptographic primitive". We need
// to use MD5 here to implement the SCRAM specification.
/* #nosec G501 */
"crypto/md5"
)
const defaultAuthDB = "admin"
func mongoPasswordDigest(username, password string) string {
// Ignore gosec warning "Use of weak cryptographic primitive". We need to use MD5 here to
// implement the SCRAM specification.
/* #nosec G401 */
h := md5.New()
_, _ = io.WriteString(h, username)
_, _ = io.WriteString(h, ":mongo:")
_, _ = io.WriteString(h, password)
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -0,0 +1,87 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"net/http"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/operation"
)
// MongoDBX509 is the mechanism name for MongoDBX509.
const MongoDBX509 = "MONGODB-X509"
func newMongoDBX509Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
// TODO(GODRIVER-3309): Validate that cred.Source is either empty or
// "$external" to make validation uniform with other auth mechanisms that
// require Source to be "$external" (e.g. MONGODB-AWS, MONGODB-OIDC, etc).
return &MongoDBX509Authenticator{User: cred.Username}, nil
}
// MongoDBX509Authenticator uses X.509 certificates over TLS to authenticate a connection.
type MongoDBX509Authenticator struct {
User string
}
var _ SpeculativeAuthenticator = (*MongoDBX509Authenticator)(nil)
// x509 represents a X509 authentication conversation. This type implements the SpeculativeConversation interface so the
// conversation can be executed in multi-step speculative fashion.
type x509Conversation struct{}
var _ SpeculativeConversation = (*x509Conversation)(nil)
// FirstMessage returns the first message to be sent to the server.
func (c *x509Conversation) FirstMessage() (bsoncore.Document, error) {
return createFirstX509Message(), nil
}
// createFirstX509Message creates the first message for the X509 conversation.
func createFirstX509Message() bsoncore.Document {
elements := [][]byte{
bsoncore.AppendInt32Element(nil, "authenticate", 1),
bsoncore.AppendStringElement(nil, "mechanism", MongoDBX509),
}
return bsoncore.BuildDocument(nil, elements...)
}
// Finish implements the SpeculativeConversation interface and is a no-op because an X509 conversation only has one
// step.
func (c *x509Conversation) Finish(context.Context, *driver.AuthConfig, bsoncore.Document) error {
return nil
}
// CreateSpeculativeConversation creates a speculative conversation for X509 authentication.
func (a *MongoDBX509Authenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
return &x509Conversation{}, nil
}
// Auth authenticates the provided connection by conducting an X509 authentication conversation.
func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error {
requestDoc := createFirstX509Message()
authCmd := operation.
NewCommand(requestDoc).
Database(sourceExternal).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
err := authCmd.Execute(ctx)
if err != nil {
return newAuthError("round trip error", err)
}
return nil
}
// Reauth reauthenticates the connection.
func (a *MongoDBX509Authenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error {
return newAuthError("X509 does not support reauthentication", nil)
}

View File

@@ -0,0 +1,593 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/codecutil"
"go.mongodb.org/mongo-driver/v2/internal/csot"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// ErrNoCursor is returned by NewCursorResponse when the database response does
// not contain a cursor.
var ErrNoCursor = errors.New("database response does not contain a cursor")
// BatchCursor is a batch implementation of a cursor. It returns documents in entire batches instead
// of one at a time. An individual document cursor can be built on top of this batch cursor.
type BatchCursor struct {
clientSession *session.Client
clock *session.ClusterClock
comment any
encoderFn codecutil.EncoderFn
database string
collection string
id int64
err error
server Server
serverDescription description.Server
errorProcessor ErrorProcessor // This will only be set when pinning to a connection.
connection *mnet.Connection
batchSize int32
currentBatch *bsoncore.Iterator
firstBatch bool
cmdMonitor *event.CommandMonitor
postBatchResumeToken bsoncore.Document
crypt Crypt
serverAPI *ServerAPIOptions
// maxAwaitTime is only valid for tailable awaitData cursors. If this option
// is set, it will be used as the "maxTimeMS" field on getMore commands.
maxAwaitTime *time.Duration
// legacy server (< 3.2) fields
limit int32
numReturned int32 // number of docs returned by server
}
// CursorResponse represents the response from a command the results in a cursor. A BatchCursor can
// be constructed from a CursorResponse.
type CursorResponse struct {
Server Server
ErrorProcessor ErrorProcessor // This will only be set when pinning to a connection.
Connection *mnet.Connection
Desc description.Server
FirstBatch *bsoncore.Iterator
Database string
Collection string
ID int64
postBatchResumeToken bsoncore.Document
}
// ExtractCursorDocument retrieves cursor document from a database response. If the
// provided response does not contain a cursor, it returns ErrNoCursor.
func ExtractCursorDocument(response bsoncore.Document) (bsoncore.Document, error) {
cur, err := response.LookupErr("cursor")
if errors.Is(err, bsoncore.ErrElementNotFound) {
return nil, ErrNoCursor
}
if err != nil {
return nil, fmt.Errorf("error getting cursor from database response: %w", err)
}
curDoc, ok := cur.DocumentOK()
if !ok {
return nil, fmt.Errorf("cursor should be an embedded document but is BSON type %s", cur.Type)
}
return curDoc, nil
}
// NewCursorResponse constructs a cursor response from the given cursor document
// extracted from a database response.
//
// NewCursorResponse can be used within the ProcessResponse method for an operation.
func NewCursorResponse(response bsoncore.Document, info ResponseInfo) (CursorResponse, error) {
elems, err := response.Elements()
if err != nil {
return CursorResponse{}, fmt.Errorf("error getting elements from cursor: %w", err)
}
curresp := CursorResponse{Server: info.Server, Desc: info.ConnectionDescription}
for _, elem := range elems {
switch elem.Key() {
case "firstBatch":
arr, ok := elem.Value().ArrayOK()
if !ok {
return CursorResponse{}, fmt.Errorf("firstBatch should be an array but is a BSON %s", elem.Value().Type)
}
curresp.FirstBatch = &bsoncore.Iterator{List: arr}
case "ns":
ns, ok := elem.Value().StringValueOK()
if !ok {
return CursorResponse{}, fmt.Errorf("ns should be a string but is a BSON %s", elem.Value().Type)
}
database, collection, ok := strings.Cut(ns, ".")
if !ok {
return CursorResponse{}, errors.New("ns field must contain a valid namespace, but is missing '.'")
}
curresp.Database = database
curresp.Collection = collection
case "id":
id, ok := elem.Value().Int64OK()
if !ok {
return CursorResponse{}, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type)
}
curresp.ID = id
case "postBatchResumeToken":
token, ok := elem.Value().DocumentOK()
if !ok {
return CursorResponse{}, fmt.Errorf("post batch resume token should be a document but it is a BSON %s", elem.Value().Type)
}
curresp.postBatchResumeToken = token
}
}
// If the deployment is behind a load balancer and the cursor has a non-zero ID, pin the cursor to a connection and
// use the same connection to execute getMore and killCursors commands.
if driverutil.IsServerLoadBalanced(curresp.Desc) && curresp.ID != 0 {
// Cache the server as an ErrorProcessor to use when constructing deployments for cursor commands.
ep, ok := curresp.Server.(ErrorProcessor)
if !ok {
return CursorResponse{}, fmt.Errorf("expected Server used to establish a cursor to implement ErrorProcessor, but got %T", curresp.Server)
}
curresp.ErrorProcessor = ep
refConn := info.Connection.Pinner
if refConn == nil {
return CursorResponse{}, fmt.Errorf("expected Connection used to establish a cursor to implement PinnedConnection, but got %T", info.Connection)
}
if err := refConn.PinToCursor(); err != nil {
return CursorResponse{}, fmt.Errorf("error incrementing connection reference count when creating a cursor: %w", err)
}
curresp.Connection = info.Connection
}
return curresp, nil
}
// CursorOptions are extra options that are required to construct a BatchCursor.
type CursorOptions struct {
BatchSize int32
Comment bsoncore.Value
Limit int32
CommandMonitor *event.CommandMonitor
Crypt Crypt
ServerAPI *ServerAPIOptions
MarshalValueEncoderFn func(io.Writer) *bson.Encoder
// MaxAwaitTime is only valid for tailable awaitData cursors. If this option
// is set, it will be used as the "maxTimeMS" field on getMore commands.
MaxAwaitTime *time.Duration
}
// SetMaxAwaitTime will set the maxTimeMS value on getMore commands for
// tailable awaitData cursors.
func (cursorOptions *CursorOptions) SetMaxAwaitTime(dur time.Duration) {
cursorOptions.MaxAwaitTime = &dur
}
// NewBatchCursor creates a new BatchCursor from the provided parameters.
func NewBatchCursor(
cr CursorResponse,
clientSession *session.Client,
clock *session.ClusterClock,
opts CursorOptions,
) (*BatchCursor, error) {
firstBatch := cr.FirstBatch
bc := &BatchCursor{
clientSession: clientSession,
clock: clock,
comment: opts.Comment,
database: cr.Database,
collection: cr.Collection,
id: cr.ID,
server: cr.Server,
connection: cr.Connection,
errorProcessor: cr.ErrorProcessor,
batchSize: opts.BatchSize,
maxAwaitTime: opts.MaxAwaitTime,
cmdMonitor: opts.CommandMonitor,
firstBatch: true,
postBatchResumeToken: cr.postBatchResumeToken,
crypt: opts.Crypt,
serverAPI: opts.ServerAPI,
serverDescription: cr.Desc,
encoderFn: opts.MarshalValueEncoderFn,
}
if firstBatch != nil {
bc.numReturned = int32(firstBatch.Count())
}
bc.currentBatch = firstBatch
return bc, nil
}
// NewEmptyBatchCursor returns a batch cursor that is empty.
func NewEmptyBatchCursor() *BatchCursor {
return &BatchCursor{currentBatch: new(bsoncore.Iterator)}
}
// NewBatchCursorFromList returns a batch cursor with current batch set to an
// itertor that can traverse the BSON data contained within the array.
func NewBatchCursorFromList(array []byte) *BatchCursor {
return &BatchCursor{
currentBatch: &bsoncore.Iterator{List: array},
id: 0,
server: nil,
}
}
// ID returns the cursor ID for this batch cursor.
func (bc *BatchCursor) ID() int64 {
return bc.id
}
// Next indicates if there is another batch available. Returning false does not necessarily indicate
// that the cursor is closed. This method will return false when an empty batch is returned.
//
// If Next returns true, there is a valid batch of documents available. If Next returns false, there
// is not a valid batch of documents available.
func (bc *BatchCursor) Next(ctx context.Context) bool {
if ctx == nil {
ctx = context.Background()
}
if bc.firstBatch {
bc.firstBatch = false
return !bc.currentBatch.Empty()
}
if bc.id == 0 || bc.server == nil {
return false
}
bc.getMore(ctx)
return !bc.currentBatch.Empty()
}
// Batch will return a DocumentSequence for the current batch of documents. The returned
// DocumentSequence is only valid until the next call to Next or Close.
func (bc *BatchCursor) Batch() *bsoncore.Iterator {
return bc.currentBatch
}
// Err returns the latest error encountered.
func (bc *BatchCursor) Err() error {
return bc.err
}
// Close closes this batch cursor.
func (bc *BatchCursor) Close(ctx context.Context) error {
if ctx == nil {
ctx = context.Background()
}
err := bc.KillCursor(ctx)
bc.id = 0
bc.currentBatch.List = nil
bc.currentBatch.Reset()
connErr := bc.unpinConnection()
if err == nil {
err = connErr
}
return err
}
func (bc *BatchCursor) unpinConnection() error {
if bc.connection == nil || bc.connection.Pinner == nil {
return nil
}
err := bc.connection.UnpinFromCursor()
closeErr := bc.connection.Close()
if err == nil && closeErr != nil {
err = closeErr
}
bc.connection = nil
return err
}
// Server returns the server for this cursor.
func (bc *BatchCursor) Server() Server {
return bc.server
}
func (bc *BatchCursor) clearBatch() {
bc.currentBatch.List = bc.currentBatch.List[:0]
}
// KillCursor kills cursor on server without closing batch cursor
func (bc *BatchCursor) KillCursor(ctx context.Context) error {
if bc.server == nil || bc.id == 0 {
return nil
}
return Operation{
CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "killCursors", bc.collection)
dst = bsoncore.BuildArrayElement(dst, "cursors", bsoncore.Value{Type: bsoncore.TypeInt64, Data: bsoncore.AppendInt64(nil, bc.id)})
return dst, nil
},
Database: bc.database,
Deployment: bc.getOperationDeployment(),
Client: bc.clientSession,
Clock: bc.clock,
Legacy: LegacyKillCursors,
CommandMonitor: bc.cmdMonitor,
ServerAPI: bc.serverAPI,
// No read preference is passed to the killCursor command,
// resulting in the default read preference: "primaryPreferred".
// Since this could be confusing, and there is no requirement
// to use a read preference here, we omit it.
omitReadPreference: true,
}.Execute(ctx)
}
// calcGetMoreBatchSize calculates the number of documents to return in the
// response of a "getMore" operation based on the given limit, batchSize, and
// number of documents already returned. Returns false if a non-trivial limit is
// lower than or equal to the number of documents already returned.
func calcGetMoreBatchSize(bc BatchCursor) (int32, bool) {
gmBatchSize := bc.batchSize
// Account for legacy operations that don't support setting a limit.
if bc.limit != 0 && bc.numReturned+bc.batchSize >= bc.limit {
gmBatchSize = bc.limit - bc.numReturned
if gmBatchSize <= 0 {
return gmBatchSize, false
}
}
return gmBatchSize, true
}
func (bc *BatchCursor) getMore(ctx context.Context) {
bc.clearBatch()
if bc.id == 0 {
return
}
numToReturn, ok := calcGetMoreBatchSize(*bc)
if !ok {
if err := bc.Close(ctx); err != nil {
bc.err = err
}
return
}
bc.err = Operation{
CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) {
// If maxAwaitTime > remaining timeoutMS - minRoundTripTime, then use
// send remaining TimeoutMS - minRoundTripTime allowing the server an
// opportunity to respond with an empty batch.
var maxTimeMS int64
if bc.maxAwaitTime != nil {
_, ctxDeadlineSet := ctx.Deadline()
if ctxDeadlineSet {
rttMonitor := bc.Server().RTTMonitor()
var ok bool
maxTimeMS, ok = driverutil.CalculateMaxTimeMS(ctx, rttMonitor.Min())
if !ok && maxTimeMS <= 0 {
return nil, fmt.Errorf(
"calculated server-side timeout (%v ms) is less than or equal to 0 (%v): %w",
maxTimeMS,
rttMonitor.Stats(),
ErrDeadlineWouldBeExceeded)
}
}
if !ctxDeadlineSet || bc.maxAwaitTime.Milliseconds() < maxTimeMS {
maxTimeMS = bc.maxAwaitTime.Milliseconds()
}
}
dst = bsoncore.AppendInt64Element(dst, "getMore", bc.id)
dst = bsoncore.AppendStringElement(dst, "collection", bc.collection)
if numToReturn > 0 {
dst = bsoncore.AppendInt32Element(dst, "batchSize", numToReturn)
}
if maxTimeMS > 0 {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", maxTimeMS)
}
comment, err := codecutil.MarshalValue(bc.comment, bc.encoderFn)
if err != nil {
return nil, fmt.Errorf("error marshaling comment as a BSON value: %w", err)
}
// The getMore command does not support commenting pre-4.4.
if comment.Type != bsoncore.Type(0) && bc.serverDescription.WireVersion.Max >= 9 {
dst = bsoncore.AppendValueElement(dst, "comment", comment)
}
return dst, nil
},
Database: bc.database,
Deployment: bc.getOperationDeployment(),
ProcessResponseFn: func(_ context.Context, response bsoncore.Document, _ ResponseInfo) error {
id, ok := response.Lookup("cursor", "id").Int64OK()
if !ok {
return fmt.Errorf("cursor.id should be an int64 but is a BSON %s", response.Lookup("cursor", "id").Type)
}
bc.id = id
batch, ok := response.Lookup("cursor", "nextBatch").ArrayOK()
if !ok {
return fmt.Errorf("cursor.nextBatch should be an array but is a BSON %s", response.Lookup("cursor", "nextBatch").Type)
}
bc.currentBatch.List = batch
bc.currentBatch.Reset()
// Required for legacy operations which don't support limit.
bc.numReturned += int32(bc.currentBatch.Count())
pbrt, err := response.LookupErr("cursor", "postBatchResumeToken")
if err != nil {
// I don't really understand why we don't set bc.err here
return nil
}
pbrtDoc, ok := pbrt.DocumentOK()
if !ok {
bc.err = fmt.Errorf("expected BSON type for post batch resume token to be EmbeddedDocument but got %s", pbrt.Type)
return nil
}
bc.postBatchResumeToken = pbrtDoc
return nil
},
Client: bc.clientSession,
Clock: bc.clock,
Legacy: LegacyGetMore,
CommandMonitor: bc.cmdMonitor,
Crypt: bc.crypt,
ServerAPI: bc.serverAPI,
// Omit the automatically-calculated maxTimeMS because setting maxTimeMS
// on a non-awaitData cursor causes a server error. For awaitData
// cursors, maxTimeMS is set when maxAwaitTime is specified by the above
// CommandFn.
OmitMaxTimeMS: true,
// No read preference is passed to the getMore command,
// resulting in the default read preference: "primaryPreferred".
// Since this could be confusing, and there is no requirement
// to use a read preference here, we omit it.
omitReadPreference: true,
}.Execute(ctx)
// Once the cursor has been drained, we can unpin the connection if one is currently pinned.
if bc.id == 0 {
err := bc.unpinConnection()
if err != nil && bc.err == nil {
bc.err = err
}
}
// If we're in load balanced mode and the pinned connection encounters a network error, we should not use it for
// future commands. Per the spec, the connection will not be unpinned until the cursor is actually closed, but
// we set the cursor ID to 0 to ensure the Close() call will not execute a killCursors command.
if driverErr, ok := bc.err.(Error); ok && driverErr.NetworkError() && bc.connection != nil {
bc.id = 0
}
// Required for legacy operations which don't support limit.
if bc.limit != 0 && bc.numReturned >= bc.limit {
// call KillCursor instead of Close because Close will clear out the data for the current batch.
err := bc.KillCursor(ctx)
if err != nil && bc.err == nil {
bc.err = err
}
}
}
// PostBatchResumeToken returns the latest seen post batch resume token.
func (bc *BatchCursor) PostBatchResumeToken() bsoncore.Document {
return bc.postBatchResumeToken
}
// SetBatchSize sets the batchSize for future getMore operations.
func (bc *BatchCursor) SetBatchSize(size int32) {
bc.batchSize = size
}
// SetMaxAwaitTime will set the maximum amount of time the server will allow the
// operations to execute. The server will error if this field is set but the
// cursor is not configured with awaitData=true.
//
// The time.Duration value passed by this setter will be converted and rounded
// down to the nearest millisecond.
func (bc *BatchCursor) SetMaxAwaitTime(dur time.Duration) {
bc.maxAwaitTime = &dur
}
// SetComment sets the comment for future getMore operations.
func (bc *BatchCursor) SetComment(comment any) {
bc.comment = comment
}
func (bc *BatchCursor) getOperationDeployment() Deployment {
if bc.connection != nil {
return &loadBalancedCursorDeployment{
errorProcessor: bc.errorProcessor,
conn: bc.connection,
}
}
return SingleServerDeployment{bc.server}
}
// MaxAwaitTime returns the maximum amount of time the server will allow
// the operations to execute. This is only valid for tailable awaitData cursors.
func (bc *BatchCursor) MaxAwaitTime() *time.Duration {
return bc.maxAwaitTime
}
// loadBalancedCursorDeployment is used as a Deployment for getMore and killCursors commands when pinning to a
// connection in load balanced mode. This type also functions as an ErrorProcessor to ensure that SDAM errors are
// handled for these commands in this mode.
type loadBalancedCursorDeployment struct {
errorProcessor ErrorProcessor
conn *mnet.Connection
}
var (
_ Deployment = (*loadBalancedCursorDeployment)(nil)
_ Server = (*loadBalancedCursorDeployment)(nil)
_ ErrorProcessor = (*loadBalancedCursorDeployment)(nil)
)
func (lbcd *loadBalancedCursorDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) {
return lbcd, nil
}
func (lbcd *loadBalancedCursorDeployment) Kind() description.TopologyKind {
return description.TopologyKindLoadBalanced
}
func (lbcd *loadBalancedCursorDeployment) Connection(context.Context) (*mnet.Connection, error) {
return lbcd.conn, nil
}
// RTTMonitor implements the driver.Server interface.
func (lbcd *loadBalancedCursorDeployment) RTTMonitor() RTTMonitor {
return &csot.ZeroRTTMonitor{}
}
func (lbcd *loadBalancedCursorDeployment) ProcessError(err error, desc mnet.Describer) ProcessErrorResult {
return lbcd.errorProcessor.ProcessError(err, desc)
}
// GetServerSelectionTimeout returns zero as a server selection timeout is not
// applicable for load-balanced cursor deployments.
func (*loadBalancedCursorDeployment) GetServerSelectionTimeout() time.Duration {
return 0
}

View File

@@ -0,0 +1,116 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"io"
"strconv"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage"
)
// Batches contains the necessary information to batch split an operation. This is only used for write
// operations.
type Batches struct {
Identifier string
Documents []bsoncore.Document
Ordered *bool
offset int
}
var _ OperationBatches = &Batches{}
// AppendBatchSequence appends dst with document sequence of batches as long as the limits of max count, max
// document size, or total size allows. It returns the number of batches appended, the new appended slice, and
// any error raised. It returns the origenal input slice if nothing can be appends within the limits.
func (b *Batches) AppendBatchSequence(dst []byte, maxCount, totalSize int) (int, []byte, error) {
if b.Size() == 0 {
return 0, dst, io.EOF
}
l := len(dst)
var idx int32
dst = wiremessage.AppendMsgSectionType(dst, wiremessage.DocumentSequence)
idx, dst = bsoncore.ReserveLength(dst)
dst = append(dst, b.Identifier...)
dst = append(dst, 0x00)
var size int
var n int
for i := b.offset; i < len(b.Documents); i++ {
if n == maxCount {
break
}
doc := b.Documents[i]
size += len(doc)
if size > totalSize {
break
}
dst = append(dst, doc...)
n++
}
if n == 0 {
return 0, dst[:l], nil
}
dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
return n, dst, nil
}
// AppendBatchArray appends dst with array of batches as long as the limits of max count, max document size, or
// total size allows. It returns the number of batches appended, the new appended slice, and any error raised. It
// returns the origenal input slice if nothing can be appends within the limits.
func (b *Batches) AppendBatchArray(dst []byte, maxCount, totalSize int) (int, []byte, error) {
if b.Size() == 0 {
return 0, dst, io.EOF
}
l := len(dst)
aidx, dst := bsoncore.AppendArrayElementStart(dst, b.Identifier)
var size int
var n int
for i := b.offset; i < len(b.Documents); i++ {
if n == maxCount {
break
}
doc := b.Documents[i]
size += len(doc)
if size > totalSize {
break
}
dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(n), doc)
n++
}
if n == 0 {
return 0, dst[:l], nil
}
var err error
dst, err = bsoncore.AppendArrayEnd(dst, aidx)
if err != nil {
return 0, nil, err
}
return n, dst, nil
}
// IsOrdered indicates if the batches are ordered.
func (b *Batches) IsOrdered() *bool {
return b.Ordered
}
// AdvanceBatches advances the batches with the given input.
func (b *Batches) AdvanceBatches(n int) {
b.offset += n
if b.offset > len(b.Documents) {
b.offset = len(b.Documents)
}
}
// Size returns the size of batches remained.
func (b *Batches) Size() int {
if b.offset > len(b.Documents) {
return 0
}
return len(b.Documents) - b.offset
}

View File

@@ -0,0 +1,194 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"bytes"
"compress/zlib"
"fmt"
"io"
"sync"
"github.com/klauspost/compress/snappy"
"github.com/klauspost/compress/zstd"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage"
)
// CompressionOpts holds settings for how to compress a payload
type CompressionOpts struct {
Compressor wiremessage.CompressorID
ZlibLevel int
ZstdLevel int
UncompressedSize int32
}
// mustZstdNewWriter creates a zstd.Encoder with the given level and a nil
// destination writer. It panics on any errors and should only be used at
// package initialization time.
func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder {
enc, err := zstd.NewWriter(
nil,
zstd.WithWindowSize(8<<20), // Set window size to 8MB.
zstd.WithEncoderLevel(lvl),
)
if err != nil {
panic(err)
}
return enc
}
var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{
0: nil, // zstd.speedNotSet
zstd.SpeedFastest: mustZstdNewWriter(zstd.SpeedFastest),
zstd.SpeedDefault: mustZstdNewWriter(zstd.SpeedDefault),
zstd.SpeedBetterCompression: mustZstdNewWriter(zstd.SpeedBetterCompression),
zstd.SpeedBestCompression: mustZstdNewWriter(zstd.SpeedBestCompression),
}
func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression {
return zstdEncoders[level], nil
}
// The level is outside the expected range, return an error.
return nil, fmt.Errorf("invalid zstd compression level: %d", level)
}
// zlibEncodersOffset is the offset into the zlibEncoders array for a given
// compression level.
const zlibEncodersOffset = -zlib.HuffmanOnly // HuffmanOnly == -2
var zlibEncoders [zlib.BestCompression + zlibEncodersOffset + 1]sync.Pool
func getZlibEncoder(level int) (*zlibEncoder, error) {
if zlib.HuffmanOnly <= level && level <= zlib.BestCompression {
if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil {
return enc, nil
}
writer, err := zlib.NewWriterLevel(nil, level)
if err != nil {
return nil, err
}
enc := &zlibEncoder{writer: writer, level: level}
return enc, nil
}
// The level is outside the expected range, return an error.
return nil, fmt.Errorf("invalid zlib compression level: %d", level)
}
func putZlibEncoder(enc *zlibEncoder) {
if enc != nil {
zlibEncoders[enc.level+zlibEncodersOffset].Put(enc)
}
}
type zlibEncoder struct {
writer *zlib.Writer
buf bytes.Buffer
level int
}
func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
defer putZlibEncoder(e)
e.buf.Reset()
e.writer.Reset(&e.buf)
_, err := e.writer.Write(src)
if err != nil {
return nil, err
}
err = e.writer.Close()
if err != nil {
return nil, err
}
dst = append(dst[:0], e.buf.Bytes()...)
return dst, nil
}
var zstdBufPool = sync.Pool{
New: func() any {
s := make([]byte, 0)
return &s
},
}
// CompressPayload takes a byte slice and compresses it according to the options passed
func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
switch opts.Compressor {
case wiremessage.CompressorNoOp:
return in, nil
case wiremessage.CompressorSnappy:
return snappy.Encode(nil, in), nil
case wiremessage.CompressorZLib:
encoder, err := getZlibEncoder(opts.ZlibLevel)
if err != nil {
return nil, err
}
return encoder.Encode(nil, in)
case wiremessage.CompressorZstd:
encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel))
if err != nil {
return nil, err
}
ptr := zstdBufPool.Get().(*[]byte)
b := encoder.EncodeAll(in, *ptr)
dst := make([]byte, len(b))
copy(dst, b)
*ptr = b[:0]
zstdBufPool.Put(ptr)
return dst, nil
default:
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
}
}
var zstdReaderPool = sync.Pool{
New: func() any {
r, _ := zstd.NewReader(nil)
return r
},
}
// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed
func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
switch opts.Compressor {
case wiremessage.CompressorNoOp:
return in, nil
case wiremessage.CompressorSnappy:
l, err := snappy.DecodedLen(in)
if err != nil {
return nil, fmt.Errorf("decoding compressed length %w", err)
} else if int32(l) != opts.UncompressedSize {
return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l)
}
out := make([]byte, opts.UncompressedSize)
return snappy.Decode(out, in)
case wiremessage.CompressorZLib:
r, err := zlib.NewReader(bytes.NewReader(in))
if err != nil {
return nil, err
}
out := make([]byte, opts.UncompressedSize)
if _, err := io.ReadFull(r, out); err != nil {
return nil, err
}
if err := r.Close(); err != nil {
return nil, err
}
return out, nil
case wiremessage.CompressorZstd:
buf := make([]byte, 0, opts.UncompressedSize)
// Using a pool here is about ~20% faster
// than using a single global zstd.Reader
r := zstdReaderPool.Get().(*zstd.Decoder)
out, err := r.DecodeAll(in, buf)
zstdReaderPool.Put(r)
return out, err
default:
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,416 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"context"
"crypto/tls"
"fmt"
"strings"
"time"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt/options"
)
const (
defaultKmsPort = 443
defaultKmsTimeout = 10 * time.Second
)
// CollectionInfoFn is a callback used to retrieve collection information.
type CollectionInfoFn func(ctx context.Context, db string, filter bsoncore.Document) (bsoncore.Document, error)
// KeyRetrieverFn is a callback used to retrieve keys from the key vault.
type KeyRetrieverFn func(ctx context.Context, filter bsoncore.Document) ([]bsoncore.Document, error)
// MarkCommandFn is a callback used to add encryption markings to a command.
type MarkCommandFn func(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
// CryptOptions specifies options to configure a Crypt instance.
type CryptOptions struct {
MongoCrypt *mongocrypt.MongoCrypt
CollInfoFn CollectionInfoFn
KeyFn KeyRetrieverFn
MarkFn MarkCommandFn
TLSConfig map[string]*tls.Config
BypassAutoEncryption bool
BypassQueryAnalysis bool
}
// Crypt is an interface implemented by types that can encrypt and decrypt instances of
// bsoncore.Document.
//
// Users should rely on the driver's crypt type (used by default) for encryption and decryption
// unless they are perfectly confident in another implementation of Crypt.
type Crypt interface {
// Encrypt encrypts the given command.
Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
// Decrypt decrypts the given command response.
Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error)
// CreateDataKey creates a data key using the given KMS provider and options.
CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error)
// EncryptExplicit encrypts the given value with the given options.
EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error)
// EncryptExplicitExpression encrypts the given expression with the given options.
EncryptExplicitExpression(ctx context.Context, val bsoncore.Document, opts *options.ExplicitEncryptionOptions) (bsoncore.Document, error)
// DecryptExplicit decrypts the given encrypted value.
DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error)
// Close cleans up any resources associated with the Crypt instance.
Close()
// BypassAutoEncryption returns true if auto-encryption should be bypassed.
BypassAutoEncryption() bool
// RewrapDataKey attempts to rewrap the document data keys matching the filter, preparing the re-wrapped documents
// to be returned as a slice of bsoncore.Document.
RewrapDataKey(ctx context.Context, filter []byte, opts *options.RewrapManyDataKeyOptions) ([]bsoncore.Document, error)
}
// crypt consumes the libmongocrypt.MongoCrypt type to iterate the mongocrypt state machine and perform encryption
// and decryption.
type crypt struct {
mongoCrypt *mongocrypt.MongoCrypt
collInfoFn CollectionInfoFn
keyFn KeyRetrieverFn
markFn MarkCommandFn
tlsConfig map[string]*tls.Config
bypassAutoEncryption bool
}
// NewCrypt creates a new Crypt instance configured with the given AutoEncryptionOptions.
func NewCrypt(opts *CryptOptions) Crypt {
c := &crypt{
mongoCrypt: opts.MongoCrypt,
collInfoFn: opts.CollInfoFn,
keyFn: opts.KeyFn,
markFn: opts.MarkFn,
tlsConfig: opts.TLSConfig,
bypassAutoEncryption: opts.BypassAutoEncryption,
}
return c
}
// Encrypt encrypts the given command.
func (c *crypt) Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error) {
if c.bypassAutoEncryption {
return cmd, nil
}
cryptCtx, err := c.mongoCrypt.CreateEncryptionContext(db, cmd)
if err != nil {
return nil, err
}
defer cryptCtx.Close()
return c.executeStateMachine(ctx, cryptCtx, db)
}
// Decrypt decrypts the given command response.
func (c *crypt) Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error) {
cryptCtx, err := c.mongoCrypt.CreateDecryptionContext(cmdResponse)
if err != nil {
return nil, err
}
defer cryptCtx.Close()
return c.executeStateMachine(ctx, cryptCtx, "")
}
// CreateDataKey creates a data key using the given KMS provider and options.
func (c *crypt) CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error) {
cryptCtx, err := c.mongoCrypt.CreateDataKeyContext(kmsProvider, opts)
if err != nil {
return nil, err
}
defer cryptCtx.Close()
return c.executeStateMachine(ctx, cryptCtx, "")
}
// RewrapDataKey attempts to rewrap the document data keys matching the filter, preparing the re-wrapped documents to
// be returned as a slice of bsoncore.Document.
func (c *crypt) RewrapDataKey(ctx context.Context, filter []byte,
opts *options.RewrapManyDataKeyOptions,
) ([]bsoncore.Document, error) {
cryptCtx, err := c.mongoCrypt.RewrapDataKeyContext(filter, opts)
if err != nil {
return nil, err
}
defer cryptCtx.Close()
rewrappedBSON, err := c.executeStateMachine(ctx, cryptCtx, "")
if err != nil {
return nil, err
}
if rewrappedBSON == nil {
return nil, nil
}
// mongocrypt_ctx_rewrap_many_datakey_init wraps the documents in a BSON of the form { "v": [(BSON document), ...] }
// where each BSON document in the slice is a document containing a rewrapped datakey.
rewrappedDocumentBytes, err := rewrappedBSON.LookupErr("v")
if err != nil {
return nil, err
}
// Parse the resulting BSON as individual documents.
rewrappedDocsArray, ok := rewrappedDocumentBytes.ArrayOK()
if !ok {
return nil, fmt.Errorf("expected results from mongocrypt_ctx_rewrap_many_datakey_init to be an array")
}
rewrappedDocumentValues, err := rewrappedDocsArray.Values()
if err != nil {
return nil, err
}
rewrappedDocuments := []bsoncore.Document{}
for _, rewrappedDocumentValue := range rewrappedDocumentValues {
if rewrappedDocumentValue.Type != bsoncore.TypeEmbeddedDocument {
// If a value in the document's array returned by mongocrypt is anything other than an embedded document,
// then something is wrong and we should terminate the routine.
return nil, fmt.Errorf("expected value of type %q, got: %q",
bsoncore.TypeEmbeddedDocument.String(),
rewrappedDocumentValue.Type.String())
}
rewrappedDocuments = append(rewrappedDocuments, rewrappedDocumentValue.Document())
}
return rewrappedDocuments, nil
}
// EncryptExplicit encrypts the given value with the given options.
func (c *crypt) EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error) {
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendValueElement(doc, "v", val)
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
cryptCtx, err := c.mongoCrypt.CreateExplicitEncryptionContext(doc, opts)
if err != nil {
return 0, nil, err
}
defer cryptCtx.Close()
res, err := c.executeStateMachine(ctx, cryptCtx, "")
if err != nil {
return 0, nil, err
}
sub, data := res.Lookup("v").Binary()
return sub, data, nil
}
// EncryptExplicitExpression encrypts the given expression with the given options.
func (c *crypt) EncryptExplicitExpression(ctx context.Context, expr bsoncore.Document, opts *options.ExplicitEncryptionOptions) (bsoncore.Document, error) {
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendDocumentElement(doc, "v", expr)
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
cryptCtx, err := c.mongoCrypt.CreateExplicitEncryptionExpressionContext(doc, opts)
if err != nil {
return nil, err
}
defer cryptCtx.Close()
res, err := c.executeStateMachine(ctx, cryptCtx, "")
if err != nil {
return nil, err
}
encryptedExpr := res.Lookup("v").Document()
return encryptedExpr, nil
}
// DecryptExplicit decrypts the given encrypted value.
func (c *crypt) DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error) {
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendBinaryElement(doc, "v", subtype, data)
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
cryptCtx, err := c.mongoCrypt.CreateExplicitDecryptionContext(doc)
if err != nil {
return bsoncore.Value{}, err
}
defer cryptCtx.Close()
res, err := c.executeStateMachine(ctx, cryptCtx, "")
if err != nil {
return bsoncore.Value{}, err
}
return res.Lookup("v"), nil
}
// Close cleans up any resources associated with the Crypt instance.
func (c *crypt) Close() {
c.mongoCrypt.Close()
}
func (c *crypt) BypassAutoEncryption() bool {
return c.bypassAutoEncryption
}
func (c *crypt) executeStateMachine(ctx context.Context, cryptCtx *mongocrypt.Context, db string) (bsoncore.Document, error) {
var err error
for {
state := cryptCtx.State()
switch state {
case mongocrypt.NeedMongoCollInfo:
err = c.collectionInfo(ctx, cryptCtx, db)
case mongocrypt.NeedMongoMarkings:
err = c.markCommand(ctx, cryptCtx, db)
case mongocrypt.NeedMongoKeys:
err = c.retrieveKeys(ctx, cryptCtx)
case mongocrypt.NeedKms:
err = c.decryptKeys(cryptCtx)
case mongocrypt.Ready:
return cryptCtx.Finish()
case mongocrypt.Done:
return nil, nil
case mongocrypt.NeedKmsCredentials:
err = c.provideKmsProviders(ctx, cryptCtx)
default:
return nil, fmt.Errorf("invalid Crypt state: %v", state)
}
if err != nil {
return nil, err
}
}
}
func (c *crypt) collectionInfo(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
op, err := cryptCtx.NextOperation()
if err != nil {
return err
}
collInfo, err := c.collInfoFn(ctx, db, op)
if err != nil {
return err
}
if collInfo != nil {
if err = cryptCtx.AddOperationResult(collInfo); err != nil {
return err
}
}
return cryptCtx.CompleteOperation()
}
func (c *crypt) markCommand(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
op, err := cryptCtx.NextOperation()
if err != nil {
return err
}
markedCmd, err := c.markFn(ctx, db, op)
if err != nil {
return err
}
if err = cryptCtx.AddOperationResult(markedCmd); err != nil {
return err
}
return cryptCtx.CompleteOperation()
}
func (c *crypt) retrieveKeys(ctx context.Context, cryptCtx *mongocrypt.Context) error {
op, err := cryptCtx.NextOperation()
if err != nil {
return err
}
keys, err := c.keyFn(ctx, op)
if err != nil {
return err
}
for _, key := range keys {
if err = cryptCtx.AddOperationResult(key); err != nil {
return err
}
}
return cryptCtx.CompleteOperation()
}
func (c *crypt) decryptKeys(cryptCtx *mongocrypt.Context) error {
for {
kmsCtx := cryptCtx.NextKmsContext()
if kmsCtx == nil {
break
}
if err := c.decryptKey(kmsCtx); err != nil {
return err
}
}
return cryptCtx.FinishKmsContexts()
}
func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
host, err := kmsCtx.HostName()
if err != nil {
return err
}
msg, err := kmsCtx.Message()
if err != nil {
return err
}
// add a port to the address if it's not already present
addr := host
if idx := strings.IndexByte(host, ':'); idx == -1 {
addr = fmt.Sprintf("%s:%d", host, defaultKmsPort)
}
kmsProvider := kmsCtx.KMSProvider()
tlsCfg := c.tlsConfig[kmsProvider]
if tlsCfg == nil {
tlsCfg = &tls.Config{MinVersion: tls.VersionTLS12}
}
conn, err := tls.Dial("tcp", addr, tlsCfg)
if err != nil {
return err
}
defer func() {
_ = conn.Close()
}()
if err = conn.SetWriteDeadline(time.Now().Add(defaultKmsTimeout)); err != nil {
return err
}
if _, err = conn.Write(msg); err != nil {
return err
}
for {
bytesNeeded := kmsCtx.BytesNeeded()
if bytesNeeded == 0 {
return nil
}
res := make([]byte, bytesNeeded)
bytesRead, err := conn.Read(res)
if err != nil {
return kmsCtx.RequestError()
}
if err = kmsCtx.FeedResponse(res[:bytesRead]); err != nil {
return err
}
}
}
func (c *crypt) provideKmsProviders(ctx context.Context, cryptCtx *mongocrypt.Context) error {
kmsProviders, err := c.mongoCrypt.GetKmsProviders(ctx)
if err != nil {
return err
}
return cryptCtx.ProvideKmsProviders(kmsProviders)
}

View File

@@ -0,0 +1,144 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package description
import (
"fmt"
"time"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo/address"
"go.mongodb.org/mongo-driver/v2/tag"
)
// ServerKind represents the type of a single server in a topology.
type ServerKind uint32
// These constants are the possible types of servers.
const (
ServerKindStandalone ServerKind = 1
ServerKindRSMember ServerKind = 2
ServerKindRSPrimary ServerKind = 4 + ServerKindRSMember
ServerKindRSSecondary ServerKind = 8 + ServerKindRSMember
ServerKindRSArbiter ServerKind = 16 + ServerKindRSMember
ServerKindRSGhost ServerKind = 32 + ServerKindRSMember
ServerKindMongos ServerKind = 256
ServerKindLoadBalancer ServerKind = 512
)
// UnknownStr represents an unknown server kind.
const UnknownStr = "Unknown"
// String returns a stringified version of the kind or "Unknown" if the kind is
// invalid.
func (kind ServerKind) String() string {
switch kind {
case ServerKindStandalone:
return "Standalone"
case ServerKindRSMember:
return "RSOther"
case ServerKindRSPrimary:
return "RSPrimary"
case ServerKindRSSecondary:
return "RSSecondary"
case ServerKindRSArbiter:
return "RSArbiter"
case ServerKindRSGhost:
return "RSGhost"
case ServerKindMongos:
return "Mongos"
case ServerKindLoadBalancer:
return "LoadBalancer"
}
return UnknownStr
}
// Unknown is an unknown server or topology kind.
const Unknown = 0
// TopologyVersion represents a software version.
type TopologyVersion struct {
ProcessID bson.ObjectID
Counter int64
}
// VersionRange represents a range of versions.
type VersionRange struct {
Min int32
Max int32
}
// Server contains information about a node in a cluster. This is created from
// hello command responses. If the value of the Kind field is LoadBalancer, only
// the Addr and Kind fields will be set. All other fields will be set to the
// zero value of the field's type.
type Server struct {
Addr address.Address
Arbiters []string
AverageRTT time.Duration
AverageRTTSet bool
Compression []string // compression methods returned by server
CanonicalAddr address.Address
ElectionID bson.ObjectID
HeartbeatInterval time.Duration
HelloOK bool
Hosts []string
IsCryptd bool
LastError error
LastUpdateTime time.Time
LastWriteTime time.Time
MaxBatchCount uint32
MaxDocumentSize uint32
MaxMessageSize uint32
Members []address.Address
Passives []string
Passive bool
Primary address.Address
ReadOnly bool
ServiceID *bson.ObjectID // Only set for servers that are deployed behind a load balancer.
SessionTimeoutMinutes *int64
SetName string
SetVersion uint32
Tags tag.Set
TopologyVersion *TopologyVersion
Kind ServerKind
WireVersion *VersionRange
}
func (s Server) String() string {
str := fmt.Sprintf("Addr: %s, Type: %s", s.Addr, s.Kind)
if len(s.Tags) != 0 {
str += fmt.Sprintf(", Tag sets: %s", s.Tags)
}
if s.AverageRTTSet {
str += fmt.Sprintf(", Average RTT: %d", s.AverageRTT)
}
if s.LastError != nil {
str += fmt.Sprintf(", Last error: %s", s.LastError)
}
return str
}
// SelectedServer augments the Server type by also including the TopologyKind of
// the topology that includes the server. This type should be used to track the
// state of a server that was selected to perform an operation.
type SelectedServer struct {
Server
Kind TopologyKind
}
// ServerSelector is an interface implemented by types that can perform server
// selection given a topology description and list of candidate servers. The
// selector should filter the provided candidates list and return a subset that
// matches some criteria.
type ServerSelector interface {
SelectServer(Topology, []Server) ([]Server, error)
}

View File

@@ -0,0 +1,60 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package description
import "fmt"
// TopologyKind represents a specific topology configuration.
type TopologyKind uint32
// These constants are the available topology configurations.
const (
TopologyKindSingle TopologyKind = 1
TopologyKindReplicaSet TopologyKind = 2
TopologyKindReplicaSetNoPrimary TopologyKind = 4 + TopologyKindReplicaSet
TopologyKindReplicaSetWithPrimary TopologyKind = 8 + TopologyKindReplicaSet
TopologyKindSharded TopologyKind = 256
TopologyKindLoadBalanced TopologyKind = 512
)
// Topology contains information about a MongoDB cluster.
type Topology struct {
Servers []Server
SetName string
Kind TopologyKind
SessionTimeoutMinutes *int64
CompatibilityErr error
}
// String implements the Stringer interface.
func (t Topology) String() string {
var serversStr string
for _, s := range t.Servers {
serversStr += "{ " + s.String() + " }, "
}
return fmt.Sprintf("Type: %s, Servers: [%s]", t.Kind, serversStr)
}
// String implements the fmt.Stringer interface.
func (kind TopologyKind) String() string {
switch kind {
case TopologyKindSingle:
return "Single"
case TopologyKindReplicaSet:
return "ReplicaSet"
case TopologyKindReplicaSetNoPrimary:
return "ReplicaSetNoPrimary"
case TopologyKindReplicaSetWithPrimary:
return "ReplicaSetWithPrimary"
case TopologyKindSharded:
return "Sharded"
case TopologyKindLoadBalanced:
return "LoadBalanced"
}
return "Unknown"
}

View File

@@ -0,0 +1,156 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package dns is intended for internal use only. It is made available to
// facilitate use cases that require access to internal MongoDB driver
// functionality and state. The API of this package is not stable and there is
// no backward compatibility guarantee.
//
// WARNING: THIS PACKAGE IS EXPERIMENTAL AND MAY BE MODIFIED OR REMOVED WITHOUT
// NOTICE! USE WITH EXTREME CAUTION!
package dns
import (
"errors"
"fmt"
"net"
"runtime"
"strings"
)
// Resolver resolves DNS records.
type Resolver struct {
// Holds the functions to use for DNS lookups
LookupSRV func(string, string, string) (string, []*net.SRV, error)
LookupTXT func(string) ([]string, error)
}
// DefaultResolver is a Resolver that uses the default Resolver from the net package.
var DefaultResolver = &Resolver{net.LookupSRV, net.LookupTXT}
// ParseHosts uses the srv string and service name to get the hosts.
func (r *Resolver) ParseHosts(host string, srvName string, stopOnErr bool) ([]string, error) {
parsedHosts := strings.Split(host, ",")
if len(parsedHosts) != 1 {
return nil, fmt.Errorf("URI with SRV must include one and only one hostname")
}
return r.fetchSeedlistFromSRV(parsedHosts[0], srvName, stopOnErr)
}
// GetConnectionArgsFromTXT gets the TXT record associated with the host and returns the connection arguments.
func (r *Resolver) GetConnectionArgsFromTXT(host string) ([]string, error) {
var connectionArgsFromTXT []string
// error ignored because not finding a TXT record should not be
// considered an error.
recordsFromTXT, _ := r.LookupTXT(host)
// This is a temporary fix to get around bug https://github.com/golang/go/issues/21472.
// It will currently incorrectly concatenate multiple TXT records to one
// on windows.
if runtime.GOOS == "windows" {
recordsFromTXT = []string{strings.Join(recordsFromTXT, "")}
}
if len(recordsFromTXT) > 1 {
return nil, errors.New("multiple records from TXT not supported")
}
if len(recordsFromTXT) > 0 {
connectionArgsFromTXT = strings.FieldsFunc(recordsFromTXT[0], func(r rune) bool { return r == ';' || r == '&' })
err := validateTXTResult(connectionArgsFromTXT)
if err != nil {
return nil, err
}
}
return connectionArgsFromTXT, nil
}
func (r *Resolver) fetchSeedlistFromSRV(host string, srvName string, stopOnErr bool) ([]string, error) {
var err error
_, _, err = net.SplitHostPort(host)
if err == nil {
// we were able to successfully extract a port from the host,
// but should not be able to when using SRV
return nil, fmt.Errorf("URI with srv must not include a port number")
}
// default to "mongodb" as service name if not supplied
if srvName == "" {
srvName = "mongodb"
}
_, addresses, err := r.LookupSRV(srvName, "tcp", host)
if err != nil && strings.Contains(err.Error(), "cannot unmarshal DNS message") {
return nil, fmt.Errorf("see https://pkg.go.dev/go.mongodb.org/mongo-driver/mongo#hdr-Potential_DNS_Issues: %w", err)
} else if err != nil {
return nil, err
}
trimmedHost := strings.TrimSuffix(host, ".")
parsedHosts := make([]string, 0, len(addresses))
for _, address := range addresses {
trimmedAddressTarget := strings.TrimSuffix(address.Target, ".")
err := validateSRVResult(trimmedAddressTarget, trimmedHost)
if err != nil {
if stopOnErr {
return nil, err
}
continue
}
parsedHosts = append(parsedHosts, fmt.Sprintf("%s:%d", trimmedAddressTarget, address.Port))
}
return parsedHosts, nil
}
func validateSRVResult(recordFromSRV, inputHostName string) error {
separatedInputDomain := strings.Split(strings.ToLower(inputHostName), ".")
separatedRecord := strings.Split(strings.ToLower(recordFromSRV), ".")
if l := len(separatedInputDomain); l < 3 && len(separatedRecord) <= l {
return fmt.Errorf("server record (%d levels) should have more domain levels than parent URI (%d levels)", l, len(separatedRecord))
}
if len(separatedRecord) < len(separatedInputDomain) {
return errors.New("domain suffix from SRV record not matched input domain")
}
inputDomainSuffix := separatedInputDomain
if len(inputDomainSuffix) > 2 {
inputDomainSuffix = inputDomainSuffix[1:]
}
domainSuffixOffset := len(separatedRecord) - len(inputDomainSuffix)
recordDomainSuffix := separatedRecord[domainSuffixOffset:]
for ix, label := range inputDomainSuffix {
if label != recordDomainSuffix[ix] {
return errors.New("domain suffix from SRV record not matched input domain")
}
}
return nil
}
var allowedTXTOptions = map[string]struct{}{
"authsource": {},
"replicaset": {},
"loadbalanced": {},
}
func validateTXTResult(paramsFromTXT []string) error {
for _, param := range paramsFromTXT {
kv := strings.SplitN(param, "=", 2)
if len(kv) != 2 {
return errors.New("invalid TXT record")
}
key := strings.ToLower(kv[0])
if _, ok := allowedTXTOptions[key]; !ok {
return fmt.Errorf("cannot specify option '%s' in TXT record", kv[0])
}
}
return nil
}

View File

@@ -0,0 +1,294 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package driver is intended for internal use only. It is made available to
// facilitate use cases that require access to internal MongoDB driver
// functionality and state. The API of this package is not stable and there is
// no backward compatibility guarantee.
//
// WARNING: THIS PACKAGE IS EXPERIMENTAL AND MAY BE MODIFIED OR REMOVED WITHOUT
// NOTICE! USE WITH EXTREME CAUTION!
package driver
import (
"context"
"time"
"go.mongodb.org/mongo-driver/v2/internal/csot"
"go.mongodb.org/mongo-driver/v2/mongo/address"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// AuthConfig holds the information necessary to perform an authentication attempt.
// this was moved from the auth package to avoid a circular dependency. The auth package
// reexports this under the old name to avoid breaking the public api.
type AuthConfig struct {
Description description.Server
Connection *mnet.Connection
ClusterClock *session.ClusterClock
HandshakeInfo HandshakeInformation
ServerAPI *ServerAPIOptions
}
// OIDCCallback is the type for both Human and Machine Callback flows. RefreshToken will always be
// nil in the OIDCArgs for the Machine flow.
type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error)
// OIDCArgs contains the arguments for the OIDC callback.
type OIDCArgs struct {
Version int
IDPInfo *IDPInfo
RefreshToken *string
}
// OIDCCredential contains the access token and refresh token.
type OIDCCredential struct {
AccessToken string
ExpiresAt *time.Time
RefreshToken *string
}
// IDPInfo contains the information needed to perform OIDC authentication with an Identity Provider.
type IDPInfo struct {
Issuer string `bson:"issuer"`
ClientID string `bson:"clientId"`
RequestScopes []string `bson:"requestScopes"`
}
// Authenticator handles authenticating a connection. The implementers of this interface
// are all in the auth package. Most authentication mechanisms do not allow for Reauth,
// but this is included in the interface so that whenever a new mechanism is added, it
// must be explicitly considered.
type Authenticator interface {
// Auth authenticates the connection.
Auth(context.Context, *AuthConfig) error
Reauth(context.Context, *AuthConfig) error
}
// Cred is a user's credential.
type Cred struct {
Source string
Username string
Password string
PasswordSet bool
Props map[string]string
OIDCMachineCallback OIDCCallback
OIDCHumanCallback OIDCCallback
}
// Deployment is implemented by types that can select a server from a deployment.
type Deployment interface {
SelectServer(context.Context, description.ServerSelector) (Server, error)
Kind() description.TopologyKind
// GetServerSelectionTimeout returns a timeout that should be used to set a
// deadline for server selection. This logic is not handleded internally by
// the ServerSelector, as a resulting deadline may be applicable by follow-up
// operations such as checking out a connection.
GetServerSelectionTimeout() time.Duration
}
// Connector represents a type that can connect to a server.
type Connector interface {
Connect() error
}
// Disconnector represents a type that can disconnect from a server.
type Disconnector interface {
Disconnect(context.Context) error
}
// Subscription represents a subscription to topology updates. A subscriber can receive updates through the
// Updates field.
type Subscription struct {
Updates <-chan description.Topology
ID uint64
}
// Subscriber represents a type to which another type can subscribe. A subscription contains a channel that
// is updated with topology descriptions.
type Subscriber interface {
Subscribe() (*Subscription, error)
Unsubscribe(*Subscription) error
}
// Server represents a MongoDB server. Implementations should pool connections and handle the
// retrieving and returning of connections.
type Server interface {
Connection(context.Context) (*mnet.Connection, error)
// RTTMonitor returns the round-trip time monitor associated with this server.
RTTMonitor() RTTMonitor
}
// RTTMonitor represents a round-trip-time monitor.
type RTTMonitor interface {
// EWMA returns the exponentially weighted moving average observed round-trip time.
EWMA() time.Duration
// Min returns the minimum observed round-trip time over the window period.
Min() time.Duration
// Stats returns stringified stats of the current state of the monitor.
Stats() string
}
var _ RTTMonitor = &csot.ZeroRTTMonitor{}
// LocalAddresser is a type that is able to supply its local address
type LocalAddresser interface {
LocalAddress() address.Address
}
// Expirable represents an expirable object.
type Expirable interface {
Expire() error
Alive() bool
}
// ProcessErrorResult represents the result of a ErrorProcessor.ProcessError() call. Exact values for this type can be
// checked directly (e.g. res == ServerMarkedUnknown), but it is recommended that applications use the ServerChanged()
// function instead.
type ProcessErrorResult int
const (
// NoChange indicates that the error did not affect the state of the server.
NoChange ProcessErrorResult = iota
// ServerMarkedUnknown indicates that the error only resulted in the server being marked as Unknown.
ServerMarkedUnknown
// ConnectionPoolCleared indicates that the error resulted in the server being marked as Unknown and its connection
// pool being cleared.
ConnectionPoolCleared
)
// ErrorProcessor implementations can handle processing errors, which may modify their internal state.
// If this type is implemented by a Server, then Operation.Execute will call it's ProcessError
// method after it decodes a wire message.
type ErrorProcessor interface {
ProcessError(err error, desc mnet.Describer) ProcessErrorResult
}
// HandshakeInformation contains information extracted from a MongoDB connection handshake. This is a helper type that
// augments description.Server by also tracking server connection ID and authentication-related fields. We use this type
// rather than adding authentication-related fields to description.Server to avoid retaining sensitive information in a
// user-facing type. The server connection ID is stored in this type because unlike description.Server, all handshakes are
// correlated with a single network connection.
type HandshakeInformation struct {
Description description.Server
SpeculativeAuthenticate bsoncore.Document
ServerConnectionID *int64
SaslSupportedMechs []string
}
// Handshaker is the interface implemented by types that can perform a MongoDB
// handshake over a provided driver.Connection. This is used during connection
// initialization. Implementations must be goroutine safe.
type Handshaker interface {
GetHandshakeInformation(context.Context, address.Address, *mnet.Connection) (HandshakeInformation, error)
FinishHandshake(context.Context, *mnet.Connection) error
}
// SingleServerDeployment is an implementation of Deployment that always returns a single server.
type SingleServerDeployment struct{ Server }
var _ Deployment = SingleServerDeployment{}
// SelectServer implements the Deployment interface. This method does not use the
// description.SelectedServer provided and instead returns the embedded Server.
func (ssd SingleServerDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) {
return ssd.Server, nil
}
// Kind implements the Deployment interface. It always returns description.TopologyKindSingle.
func (SingleServerDeployment) Kind() description.TopologyKind { return description.TopologyKindSingle }
// GetServerSelectionTimeout returns zero as a server selection timeout is not
// applicable for single server deployments.
func (SingleServerDeployment) GetServerSelectionTimeout() time.Duration {
return 0
}
// SingleConnectionDeployment is an implementation of Deployment that always returns the same Connection. This
// implementation should only be used for connection handshakes and server heartbeats as it does not implement
// ErrorProcessor, which is necessary for application operations.
type SingleConnectionDeployment struct{ C *mnet.Connection }
var (
_ Deployment = SingleConnectionDeployment{}
_ Server = SingleConnectionDeployment{}
)
// SelectServer implements the Deployment interface. This method does not use the
// description.SelectedServer provided and instead returns itself. The Connections returned from the
// Connection method have a no-op Close method.
func (scd SingleConnectionDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) {
return scd, nil
}
// GetServerSelectionTimeout returns zero as a server selection timeout is not
// applicable for single connection deployment.
func (SingleConnectionDeployment) GetServerSelectionTimeout() time.Duration {
return 0
}
// Kind implements the Deployment interface. It always returns description.TopologyKindSingle.
func (SingleConnectionDeployment) Kind() description.TopologyKind {
return description.TopologyKindSingle
}
// Connection implements the Server interface. It always returns the embedded connection.
func (scd SingleConnectionDeployment) Connection(context.Context) (*mnet.Connection, error) {
return scd.C, nil
}
// RTTMonitor implements the driver.Server interface.
func (scd SingleConnectionDeployment) RTTMonitor() RTTMonitor {
return &csot.ZeroRTTMonitor{}
}
// TODO(GODRIVER-617): We can likely use 1 type for both the Type and the RetryMode by using 2 bits for the mode and 1
// TODO bit for the type. Although in the practical sense, we might not want to do that since the type of retryability
// TODO is tied to the operation itself and isn't going change, e.g. and insert operation will always be a write,
// TODO however some operations are both reads and writes, for instance aggregate is a read but with a $out parameter
// TODO it's a write.
// Type specifies whether an operation is a read, write, or unknown.
type Type uint
// THese are the availables types of Type.
const (
_ Type = iota
Write
Read
)
// RetryMode specifies the way that retries are handled for retryable operations.
type RetryMode uint
// These are the modes available for retrying. Note that if Timeout is specified on the Client, the
// operation will automatically retry as many times as possible within the context's deadline
// unless RetryNone is used.
const (
// RetryNone disables retrying.
RetryNone RetryMode = iota
// RetryOnce will enable retrying the entire operation once if Timeout is not specified.
RetryOnce
// RetryOncePerCommand will enable retrying each command associated with an operation if Timeout
// is not specified. For example, if an insert is batch split into 4 commands then each of
// those commands is eligible for one retry.
RetryOncePerCommand
// RetryContext will enable retrying until the context.Context's deadline is exceeded or it is
// cancelled.
RetryContext
)
// Enabled returns if this RetryMode enables retrying.
func (rm RetryMode) Enabled() bool {
return rm == RetryOnce || rm == RetryOncePerCommand || rm == RetryContext
}

View File

@@ -0,0 +1,568 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"bytes"
"context"
"errors"
"fmt"
"strings"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
)
// LegacyNotPrimaryErrMsg is the error message that older MongoDB servers (see
// SERVER-50412 for versions) return when a write operation is erroneously sent
// to a non-primary node.
const LegacyNotPrimaryErrMsg = "not master"
var (
retryableCodes = []int32{
6, // HostUnreachable
7, // HostNotFound
89, // NetworkTimeout
91, // ShutdownInProgress
134, // ReadConcernMajorityNotAvailableYet
189, // PrimarySteppedDown
262, // ExceededTimeLimit
9001, // SocketException
10107, // NotWritablePrimary
11600, // InterruptedAtShutdown
11602, // InterruptedDueToReplStateChange
13435, // NotPrimaryNoSecondaryOk
13436, // NotPrimaryOrSecondary
}
nodeIsRecoveringCodes = []int32{11600, 11602, 13436, 189, 91}
notPrimaryCodes = []int32{10107, 13435, 10058}
nodeIsShuttingDownCodes = []int32{11600, 91}
unknownReplWriteConcernCode = int32(79)
unsatisfiableWriteConcernCode = int32(100)
)
var (
// UnknownTransactionCommitResult is an error label for unknown transaction commit results.
UnknownTransactionCommitResult = "UnknownTransactionCommitResult"
// TransientTransactionError is an error label for transient errors with transactions.
TransientTransactionError = "TransientTransactionError"
// NetworkError is an error label for network errors.
NetworkError = "NetworkError"
// RetryableWriteError is an error label for retryable write errors.
RetryableWriteError = "RetryableWriteError"
// NoWritesPerformed is an error label indicated that no writes were performed for an operation.
NoWritesPerformed = "NoWritesPerformed"
// ErrCursorNotFound is the cursor not found error for legacy find operations.
ErrCursorNotFound = errors.New("cursor not found")
// ErrUnacknowledgedWrite is returned from functions that have an unacknowledged
// write concern.
ErrUnacknowledgedWrite = errors.New("unacknowledged write")
// ErrUnsupportedStorageEngine is returned when a retryable write is attempted against a server
// that uses a storage engine that does not support retryable writes
ErrUnsupportedStorageEngine = errors.New("this MongoDB deployment does not support retryable writes. Please add retryWrites=false to your connection string")
// ErrDeadlineWouldBeExceeded is returned when a Timeout set on an operation
// would be exceeded if the operation were sent to the server. It wraps
// context.DeadlineExceeded.
ErrDeadlineWouldBeExceeded = fmt.Errorf(
"operation not sent to server, as Timeout would be exceeded: %w",
context.DeadlineExceeded)
)
// QueryFailureError is an error representing a command failure as a document.
type QueryFailureError struct {
Message string
Response bsoncore.Document
Wrapped error
}
// Error implements the error interface.
func (e QueryFailureError) Error() string {
return fmt.Sprintf("%s: %v", e.Message, e.Response)
}
// Unwrap returns the underlying error.
func (e QueryFailureError) Unwrap() error {
return e.Wrapped
}
// ResponseError is an error parsing the response to a command.
type ResponseError struct {
Message string
Wrapped error
}
// NewCommandResponseError creates a CommandResponseError.
func NewCommandResponseError(msg string, err error) ResponseError {
return ResponseError{Message: msg, Wrapped: err}
}
// Error implements the error interface.
func (e ResponseError) Error() string {
if e.Wrapped != nil {
return fmt.Sprintf("%s: %s", e.Message, e.Wrapped)
}
return e.Message
}
// WriteCommandError is an error for a write command.
type WriteCommandError struct {
WriteConcernError *WriteConcernError
WriteErrors WriteErrors
Labels []string
Raw bsoncore.Document
}
// UnsupportedStorageEngine returns whether or not the WriteCommandError comes from a retryable write being attempted
// against a server that has a storage engine where they are not supported
func (wce WriteCommandError) UnsupportedStorageEngine() bool {
for _, writeError := range wce.WriteErrors {
if writeError.Code == 20 && strings.HasPrefix(strings.ToLower(writeError.Message), "transaction numbers") {
return true
}
}
return false
}
func (wce WriteCommandError) Error() string {
var buf bytes.Buffer
fmt.Fprint(&buf, "write command error: [")
fmt.Fprintf(&buf, "{%s}, ", wce.WriteErrors)
fmt.Fprintf(&buf, "{%s}]", wce.WriteConcernError)
return buf.String()
}
// Retryable returns true if the error is retryable
func (wce WriteCommandError) Retryable(serverKind description.ServerKind, wireVersion *description.VersionRange) bool {
for _, label := range wce.Labels {
if label == RetryableWriteError {
return true
}
}
if wireVersion != nil && wireVersion.Max >= 9 {
return false
}
if wce.WriteConcernError == nil {
return false
}
return wce.WriteConcernError.Retryable(serverKind, wireVersion)
}
// HasErrorLabel returns true if the error contains the specified label.
func (wce WriteCommandError) HasErrorLabel(label string) bool {
for _, l := range wce.Labels {
if l == label {
return true
}
}
return false
}
// WriteConcernError is a write concern failure that occurred as a result of a
// write operation.
type WriteConcernError struct {
Name string
Code int64
Message string
Details bsoncore.Document
Labels []string
TopologyVersion *description.TopologyVersion
Raw bsoncore.Document
}
func (wce WriteConcernError) Error() string {
if wce.Name != "" {
return fmt.Sprintf("(%v) %v", wce.Name, wce.Message)
}
return wce.Message
}
// Retryable returns true if the error is retryable
func (wce WriteConcernError) Retryable(serverKind description.ServerKind, wireVersion *description.VersionRange) bool {
if serverKind == description.ServerKindMongos && wireVersion.Max < 9 {
// For a pre-4.4 mongos response, we can trust that mongos will have already
// retried the operation if necessary. Drivers should not retry to avoid
// "excessive retrying".
return false
}
for _, code := range retryableCodes {
if wce.Code == int64(code) {
return true
}
}
return false
}
// NodeIsRecovering returns true if this error is a node is recovering error.
func (wce WriteConcernError) NodeIsRecovering() bool {
for _, code := range nodeIsRecoveringCodes {
if wce.Code == int64(code) {
return true
}
}
hasNoCode := wce.Code == 0
return hasNoCode && strings.Contains(wce.Message, "node is recovering")
}
// NodeIsShuttingDown returns true if this error is a node is shutting down error.
func (wce WriteConcernError) NodeIsShuttingDown() bool {
for _, code := range nodeIsShuttingDownCodes {
if wce.Code == int64(code) {
return true
}
}
hasNoCode := wce.Code == 0
return hasNoCode && strings.Contains(wce.Message, "node is shutting down")
}
// NotPrimary returns true if this error is a not primary error.
func (wce WriteConcernError) NotPrimary() bool {
for _, code := range notPrimaryCodes {
if wce.Code == int64(code) {
return true
}
}
hasNoCode := wce.Code == 0
return hasNoCode && strings.Contains(wce.Message, LegacyNotPrimaryErrMsg)
}
// WriteError is a non-write concern failure that occurred as a result of a write
// operation.
type WriteError struct {
Index int64
Code int64
Message string
Details bsoncore.Document
Raw bsoncore.Document
}
func (we WriteError) Error() string { return we.Message }
// WriteErrors is a group of non-write concern failures that occurred as a result
// of a write operation.
type WriteErrors []WriteError
func (we WriteErrors) Error() string {
var buf bytes.Buffer
fmt.Fprint(&buf, "write errors: [")
for idx, err := range we {
if idx != 0 {
fmt.Fprintf(&buf, ", ")
}
fmt.Fprintf(&buf, "{%s}", err)
}
fmt.Fprint(&buf, "]")
return buf.String()
}
// Error is a command execution error from the database.
type Error struct {
Code int32
Message string
Labels []string
Name string
Wrapped error
TopologyVersion *description.TopologyVersion
Raw bsoncore.Document
}
// UnsupportedStorageEngine returns whether e came as a result of an unsupported storage engine
func (e Error) UnsupportedStorageEngine() bool {
return e.Code == 20 && strings.HasPrefix(strings.ToLower(e.Message), "transaction numbers")
}
// Error implements the error interface.
func (e Error) Error() string {
var msg string
if e.Name != "" {
msg = fmt.Sprintf("(%v)", e.Name)
}
msg += " " + e.Message
if e.Wrapped != nil {
msg += ": " + e.Wrapped.Error()
}
return msg
}
// Unwrap returns the underlying error.
func (e Error) Unwrap() error {
return e.Wrapped
}
// HasErrorLabel returns true if the error contains the specified label.
func (e Error) HasErrorLabel(label string) bool {
for _, l := range e.Labels {
if l == label {
return true
}
}
return false
}
// RetryableRead returns true if the error is retryable for a read operation
func (e Error) RetryableRead() bool {
for _, label := range e.Labels {
if label == NetworkError {
return true
}
}
for _, code := range retryableCodes {
if e.Code == code {
return true
}
}
return false
}
// RetryableWrite returns true if the error is retryable for a write operation
func (e Error) RetryableWrite(wireVersion *description.VersionRange) bool {
for _, label := range e.Labels {
if label == NetworkError || label == RetryableWriteError {
return true
}
}
if wireVersion != nil && wireVersion.Max >= 9 {
return false
}
for _, code := range retryableCodes {
if e.Code == code {
return true
}
}
return false
}
// NetworkError returns true if the error is a network error.
func (e Error) NetworkError() bool {
for _, label := range e.Labels {
if label == NetworkError {
return true
}
}
return false
}
// NodeIsRecovering returns true if this error is a node is recovering error.
func (e Error) NodeIsRecovering() bool {
for _, code := range nodeIsRecoveringCodes {
if e.Code == code {
return true
}
}
hasNoCode := e.Code == 0
return hasNoCode && strings.Contains(e.Message, "node is recovering")
}
// NodeIsShuttingDown returns true if this error is a node is shutting down error.
func (e Error) NodeIsShuttingDown() bool {
for _, code := range nodeIsShuttingDownCodes {
if e.Code == code {
return true
}
}
hasNoCode := e.Code == 0
return hasNoCode && strings.Contains(e.Message, "node is shutting down")
}
// NotPrimary returns true if this error is a not primary error.
func (e Error) NotPrimary() bool {
for _, code := range notPrimaryCodes {
if e.Code == code {
return true
}
}
hasNoCode := e.Code == 0
return hasNoCode && strings.Contains(e.Message, LegacyNotPrimaryErrMsg)
}
// NamespaceNotFound returns true if this errors is a NamespaceNotFound error.
func (e Error) NamespaceNotFound() bool {
return e.Code == 26 || e.Message == "ns not found"
}
// ExtractErrorFromServerResponse extracts an error from a server response bsoncore.Document
// if there is one. Also used in testing for SDAM.
func ExtractErrorFromServerResponse(doc bsoncore.Document) error {
var errmsg, codeName string
var code int32
var labels []string
var ok bool
var tv *description.TopologyVersion
var wcError WriteCommandError
elems, err := doc.Elements()
if err != nil {
return err
}
for _, elem := range elems {
switch elem.Key() {
case "ok":
switch elem.Value().Type {
case bsoncore.TypeInt32:
if elem.Value().Int32() == 1 {
ok = true
}
case bsoncore.TypeInt64:
if elem.Value().Int64() == 1 {
ok = true
}
case bsoncore.TypeDouble:
if elem.Value().Double() == 1 {
ok = true
}
case bsoncore.TypeBoolean:
if elem.Value().Boolean() {
ok = true
}
}
case "errmsg":
if str, okay := elem.Value().StringValueOK(); okay {
errmsg = str
}
case "codeName":
if str, okay := elem.Value().StringValueOK(); okay {
codeName = str
}
case "code":
if c, okay := elem.Value().Int32OK(); okay {
code = c
}
case "errorLabels":
if arr, okay := elem.Value().ArrayOK(); okay {
vals, err := arr.Values()
if err != nil {
continue
}
for _, val := range vals {
if str, ok := val.StringValueOK(); ok {
labels = append(labels, str)
}
}
}
case "writeErrors":
arr, exists := elem.Value().ArrayOK()
if !exists {
break
}
vals, err := arr.Values()
if err != nil {
continue
}
for _, val := range vals {
var we WriteError
doc, exists := val.DocumentOK()
if !exists {
continue
}
if index, exists := doc.Lookup("index").AsInt64OK(); exists {
we.Index = index
}
if code, exists := doc.Lookup("code").AsInt64OK(); exists {
we.Code = code
}
if msg, exists := doc.Lookup("errmsg").StringValueOK(); exists {
we.Message = msg
}
if info, exists := doc.Lookup("errInfo").DocumentOK(); exists {
we.Details = make([]byte, len(info))
copy(we.Details, info)
}
we.Raw = doc
wcError.WriteErrors = append(wcError.WriteErrors, we)
}
case "writeConcernError":
doc, exists := elem.Value().DocumentOK()
if !exists {
break
}
wcError.WriteConcernError = new(WriteConcernError)
wcError.WriteConcernError.Raw = doc
if code, exists := doc.Lookup("code").AsInt64OK(); exists {
wcError.WriteConcernError.Code = code
}
if name, exists := doc.Lookup("codeName").StringValueOK(); exists {
wcError.WriteConcernError.Name = name
}
if msg, exists := doc.Lookup("errmsg").StringValueOK(); exists {
wcError.WriteConcernError.Message = msg
}
if info, exists := doc.Lookup("errInfo").DocumentOK(); exists {
wcError.WriteConcernError.Details = make([]byte, len(info))
copy(wcError.WriteConcernError.Details, info)
}
if errLabels, exists := doc.Lookup("errorLabels").ArrayOK(); exists {
vals, err := errLabels.Values()
if err != nil {
continue
}
for _, val := range vals {
if str, ok := val.StringValueOK(); ok {
labels = append(labels, str)
}
}
}
case "topologyVersion":
doc, ok := elem.Value().DocumentOK()
if !ok {
break
}
version, err := driverutil.NewTopologyVersion(bson.Raw(doc))
if err == nil {
tv = version
}
}
}
if !ok {
if errmsg == "" {
errmsg = "command failed"
}
err := Error{
Code: code,
Message: errmsg,
Name: codeName,
Labels: labels,
TopologyVersion: tv,
Raw: doc,
}
// If we get a MaxTimeMSExpired error, assume that the error was caused
// by setting "maxTimeMS" on the command based on the context deadline
// or on "timeoutMS". In that case, make the error wrap
// context.DeadlineExceeded so that users can always check
//
// errors.Is(err, context.DeadlineExceeded)
//
// for either client-side or server-side timeouts.
if err.Code == 50 {
err.Wrapped = context.DeadlineExceeded
}
return err
}
if len(wcError.WriteErrors) > 0 || wcError.WriteConcernError != nil {
wcError.Labels = labels
if wcError.WriteConcernError != nil {
wcError.WriteConcernError.TopologyVersion = tv
}
wcError.Raw = doc
return wcError
}
return nil
}

View File

@@ -0,0 +1,23 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
// LegacyOperationKind indicates if an operation is a legacy find, getMore, or killCursors. This is used
// in Operation.Execute, which will create legacy OP_QUERY, OP_GET_MORE, or OP_KILL_CURSORS instead
// of sending them as a command.
type LegacyOperationKind uint
// These constants represent the three different kinds of legacy operations.
const (
LegacyNone LegacyOperationKind = iota
LegacyFind
LegacyGetMore
LegacyKillCursors
LegacyListCollections
LegacyListIndexes
LegacyHandshake
)

View File

@@ -0,0 +1,121 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package mnet
import (
"context"
"io"
"go.mongodb.org/mongo-driver/v2/mongo/address"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
)
// ReadWriteCloser represents a Connection where server operations
// can read from, written to, and closed.
type ReadWriteCloser interface {
Read(ctx context.Context) ([]byte, error)
Write(ctx context.Context, wm []byte) error
io.Closer
}
// Describer represents a Connection that can be described.
type Describer interface {
Description() description.Server
ID() string
ServerConnectionID() *int64
DriverConnectionID() int64
Address() address.Address
Stale() bool
OIDCTokenGenID() uint64
SetOIDCTokenGenID(uint64)
}
// Streamer represents a Connection that supports streaming wire protocol
// messages using the moreToCome and exhaustAllowed flags.
//
// The SetStreaming and CurrentlyStreaming functions correspond to the
// moreToCome flag on server responses. If a response has moreToCome set,
// SetStreaming(true) will be called and CurrentlyStreaming() should return
// true.
//
// CanStream corresponds to the exhaustAllowed flag. The operations layer will
// set exhaustAllowed on outgoing wire messages to inform the server that the
// driver supports streaming.
type Streamer interface {
SetStreaming(bool)
CurrentlyStreaming() bool
SupportsStreaming() bool
}
// Compressor is an interface used to compress wire messages. If a Connection
// supports compression it should implement this interface as well. The
// CompressWireMessage method will be called during the execution of an
// operation if the wire message is allowed to be compressed.
type Compressor interface {
CompressWireMessage(src, dst []byte) ([]byte, error)
}
// Pinner represents a Connection that can be pinned by one or more cursors or
// transactions. Implementations of this interface should maintain the following
// invariants:
//
// 1. Each Pin* call should increment the number of references for the
// connection.
// 2. Each Unpin* call should decrement the number of references for the
// connection.
// 3. Calls to Close() should be ignored until all resources have unpinned the
// connection.
type Pinner interface {
PinToCursor() error
PinToTransaction() error
UnpinFromCursor() error
UnpinFromTransaction() error
}
// Connection represents a connection to a MongoDB server.
type Connection struct {
ReadWriteCloser
Describer
Streamer
Compressor
Pinner
}
// NewConnection creates a new Connection with the provided component. This
// constructor returns a component that is already a Connection to avoid
// mis-asserting the composite interfaces.
func NewConnection(component interface {
ReadWriteCloser
Describer
},
) *Connection {
if _, ok := component.(*Connection); ok {
return component.(*Connection)
}
conn := &Connection{
ReadWriteCloser: component,
}
if describer, ok := component.(Describer); ok {
conn.Describer = describer
}
if streamer, ok := component.(Streamer); ok {
conn.Streamer = streamer
}
if compressor, ok := component.(Compressor); ok {
conn.Compressor = compressor
}
if pinner, ok := component.(Pinner); ok {
conn.Pinner = pinner
}
return conn
}

View File

@@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build cse
package mongocrypt
/*
#include <stdlib.h>
#include <mongocrypt.h>
*/
import "C"
import (
"unsafe"
)
// binary is a wrapper type around a mongocrypt_binary_t*
type binary struct {
p *C.uint8_t
wrapped *C.mongocrypt_binary_t
}
// newBinary creates an empty binary instance.
func newBinary() *binary {
return &binary{
wrapped: C.mongocrypt_binary_new(),
}
}
// newBinaryFromBytes creates a binary instance from a byte buffer.
func newBinaryFromBytes(data []byte) *binary {
if len(data) == 0 {
return newBinary()
}
// TODO: Consider using runtime.Pinner to replace the C.CBytes after using go1.21.0.
addr := (*C.uint8_t)(C.CBytes(data)) // uint8_t*
dataLen := C.uint32_t(len(data)) // uint32_t
return &binary{
p: addr,
wrapped: C.mongocrypt_binary_new_from_data(addr, dataLen),
}
}
// toBytes converts the given binary instance to []byte.
func (b *binary) toBytes() []byte {
dataPtr := C.mongocrypt_binary_data(b.wrapped) // C.uint8_t*
dataLen := C.mongocrypt_binary_len(b.wrapped) // C.uint32_t
return C.GoBytes(unsafe.Pointer(dataPtr), C.int(dataLen))
}
// close cleans up any resources associated with the given binary instance.
func (b *binary) close() {
if b.p != nil {
C.free(unsafe.Pointer(b.p))
}
C.mongocrypt_binary_destroy(b.wrapped)
}

View File

@@ -0,0 +1,44 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build cse
package mongocrypt
// #include <mongocrypt.h>
import "C"
import (
"fmt"
)
// Error represents an error from an operation on a MongoCrypt instance.
type Error struct {
Code int32
Message string
}
// Error implements the error interface.
func (e Error) Error() string {
return fmt.Sprintf("mongocrypt error %d: %v", e.Code, e.Message)
}
// errorFromStatus builds a Error from a mongocrypt_status_t object.
func errorFromStatus(status *C.mongocrypt_status_t) error {
cCode := C.mongocrypt_status_code(status) // uint32_t
// mongocrypt_status_message takes uint32_t* as its second param to store the length of the returned string.
// pass nil because the length is handled by C.GoString
cMsg := C.mongocrypt_status_message(status, nil) // const char*
var msg string
if cMsg != nil {
msg = C.GoString(cMsg)
}
return Error{
Code: int32(cCode),
Message: msg,
}
}

View File

@@ -0,0 +1,20 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build !cse
package mongocrypt
// Error represents an error from an operation on a MongoCrypt instance.
type Error struct {
Code int32
Message string
}
// Error implements the error interface
func (Error) Error() string {
panic(cseNotSupportedMsg)
}

View File

@@ -0,0 +1,586 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build cse
package mongocrypt
// #cgo linux solaris darwin pkg-config: libmongocrypt
// #cgo windows CFLAGS: -I"c:/libmongocrypt/include"
// #cgo windows LDFLAGS: -lmongocrypt -Lc:/libmongocrypt/bin
// #include <mongocrypt.h>
// #include <stdlib.h>
import "C"
import (
"context"
"errors"
"fmt"
"net/http"
"time"
"unsafe"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/httputil"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth/creds"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt/options"
)
type kmsProvider interface {
GetCredentialsDoc(context.Context) (bsoncore.Document, error)
}
type MongoCrypt struct {
wrapped *C.mongocrypt_t
kmsProviders map[string]kmsProvider
httpClient *http.Client
}
// Version returns the version string for the loaded libmongocrypt, or an empty string
// if libmongocrypt was not loaded.
func Version() string {
str := C.GoString(C.mongocrypt_version(nil))
return str
}
// NewMongoCrypt constructs a new MongoCrypt instance configured using the provided MongoCryptOptions.
func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) {
// create mongocrypt_t handle
wrapped := C.mongocrypt_new()
if wrapped == nil {
return nil, errors.New("could not create new mongocrypt object")
}
C.mongocrypt_setopt_retry_kms(wrapped, true)
httpClient := opts.HTTPClient
if httpClient == nil {
httpClient = httputil.DefaultHTTPClient
}
kmsProviders := make(map[string]kmsProvider)
if needsKmsProvider(opts.KmsProviders, "gcp") {
kmsProviders["gcp"] = creds.NewGCPCredentialProvider(httpClient)
}
if needsKmsProvider(opts.KmsProviders, "aws") {
kmsProviders["aws"] = creds.NewAWSCredentialProvider(httpClient)
}
if needsKmsProvider(opts.KmsProviders, "azure") {
kmsProviders["azure"] = creds.NewAzureCredentialProvider(httpClient)
}
crypt := &MongoCrypt{
wrapped: wrapped,
kmsProviders: kmsProviders,
httpClient: httpClient,
}
// set options in mongocrypt
if err := crypt.setProviderOptions(opts.KmsProviders); err != nil {
return nil, err
}
if err := crypt.setLocalSchemaMap(opts.LocalSchemaMap); err != nil {
return nil, err
}
if err := crypt.setEncryptedFieldsMap(opts.EncryptedFieldsMap); err != nil {
return nil, err
}
if opts.BypassQueryAnalysis {
C.mongocrypt_setopt_bypass_query_analysis(crypt.wrapped)
}
var keyExpirationMs uint64 = 60_000 // 60,000 ms
if opts.KeyExpiration != nil {
if *opts.KeyExpiration <= 0 {
keyExpirationMs = 0
} else {
// find the ceiling integer millisecond for the expiration
keyExpirationMs = uint64((*opts.KeyExpiration + time.Millisecond - 1) / time.Millisecond)
}
}
C.mongocrypt_setopt_key_expiration(crypt.wrapped, C.uint64_t(keyExpirationMs))
// If loading the crypt_shared library isn't disabled, set the default library search path "$SYSTEM"
// and set a library override path if one was provided.
if !opts.CryptSharedLibDisabled {
systemStr := C.CString("$SYSTEM")
defer C.free(unsafe.Pointer(systemStr))
C.mongocrypt_setopt_append_crypt_shared_lib_search_path(crypt.wrapped, systemStr)
if opts.CryptSharedLibOverridePath != "" {
cryptSharedLibOverridePathStr := C.CString(opts.CryptSharedLibOverridePath)
defer C.free(unsafe.Pointer(cryptSharedLibOverridePathStr))
C.mongocrypt_setopt_set_crypt_shared_lib_path_override(crypt.wrapped, cryptSharedLibOverridePathStr)
}
}
C.mongocrypt_setopt_use_need_kms_credentials_state(crypt.wrapped)
// initialize handle
if !C.mongocrypt_init(crypt.wrapped) {
return nil, crypt.createErrorFromStatus()
}
return crypt, nil
}
// CreateEncryptionContext creates a Context to use for encryption.
func (m *MongoCrypt) CreateEncryptionContext(db string, cmd bsoncore.Document) (*Context, error) {
ctx := newContext(C.mongocrypt_ctx_new(m.wrapped))
if ctx.wrapped == nil {
return nil, m.createErrorFromStatus()
}
cmdBinary := newBinaryFromBytes(cmd)
defer cmdBinary.close()
dbStr := C.CString(db)
defer C.free(unsafe.Pointer(dbStr))
if ok := C.mongocrypt_ctx_encrypt_init(ctx.wrapped, dbStr, C.int32_t(-1), cmdBinary.wrapped); !ok {
return nil, ctx.createErrorFromStatus()
}
return ctx, nil
}
// CreateDecryptionContext creates a Context to use for decryption.
func (m *MongoCrypt) CreateDecryptionContext(cmd bsoncore.Document) (*Context, error) {
ctx := newContext(C.mongocrypt_ctx_new(m.wrapped))
if ctx.wrapped == nil {
return nil, m.createErrorFromStatus()
}
cmdBinary := newBinaryFromBytes(cmd)
defer cmdBinary.close()
if ok := C.mongocrypt_ctx_decrypt_init(ctx.wrapped, cmdBinary.wrapped); !ok {
return nil, ctx.createErrorFromStatus()
}
return ctx, nil
}
// lookupString returns a string for the value corresponding to the given key in the document.
// if the key does not exist or the value is not a string, the empty string is returned.
func lookupString(doc bsoncore.Document, key string) string {
strVal, _ := doc.Lookup(key).StringValueOK()
return strVal
}
func setAltName(ctx *Context, altName string) error {
// create document {"keyAltName": keyAltName}
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendStringElement(doc, "keyAltName", altName)
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
keyAltBinary := newBinaryFromBytes(doc)
defer keyAltBinary.close()
if ok := C.mongocrypt_ctx_setopt_key_alt_name(ctx.wrapped, keyAltBinary.wrapped); !ok {
return ctx.createErrorFromStatus()
}
return nil
}
func setKeyMaterial(ctx *Context, keyMaterial []byte) error {
// Create document {"keyMaterial": keyMaterial} using the generic binary sybtype 0x00.
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendBinaryElement(doc, "keyMaterial", 0x00, keyMaterial)
doc, err := bsoncore.AppendDocumentEnd(doc, idx)
if err != nil {
return err
}
keyMaterialBinary := newBinaryFromBytes(doc)
defer keyMaterialBinary.close()
if ok := C.mongocrypt_ctx_setopt_key_material(ctx.wrapped, keyMaterialBinary.wrapped); !ok {
return ctx.createErrorFromStatus()
}
return nil
}
func rewrapDataKey(ctx *Context, filter []byte) error {
filterBinary := newBinaryFromBytes(filter)
defer filterBinary.close()
if ok := C.mongocrypt_ctx_rewrap_many_datakey_init(ctx.wrapped, filterBinary.wrapped); !ok {
return ctx.createErrorFromStatus()
}
return nil
}
// CreateDataKeyContext creates a Context to use for creating a data key.
func (m *MongoCrypt) CreateDataKeyContext(kmsProvider string, opts *options.DataKeyOptions) (*Context, error) {
ctx := newContext(C.mongocrypt_ctx_new(m.wrapped))
if ctx.wrapped == nil {
return nil, m.createErrorFromStatus()
}
// Create a masterKey document of the form { "provider": <provider string>, other options... }.
var masterKey bsoncore.Document
switch {
case opts.MasterKey != nil:
// The original key passed into the top-level API was already transformed into a raw BSON document and passed
// down to here, so we can modify it without copying. Remove the terminating byte to add the "provider" field.
masterKey = opts.MasterKey[:len(opts.MasterKey)-1]
masterKey = bsoncore.AppendStringElement(masterKey, "provider", kmsProvider)
masterKey, _ = bsoncore.AppendDocumentEnd(masterKey, 0)
default:
masterKey = bsoncore.NewDocumentBuilder().AppendString("provider", kmsProvider).Build()
}
masterKeyBinary := newBinaryFromBytes(masterKey)
defer masterKeyBinary.close()
if ok := C.mongocrypt_ctx_setopt_key_encryption_key(ctx.wrapped, masterKeyBinary.wrapped); !ok {
return nil, ctx.createErrorFromStatus()
}
for _, altName := range opts.KeyAltNames {
if err := setAltName(ctx, altName); err != nil {
return nil, err
}
}
if opts.KeyMaterial != nil {
if err := setKeyMaterial(ctx, opts.KeyMaterial); err != nil {
return nil, err
}
}
if ok := C.mongocrypt_ctx_datakey_init(ctx.wrapped); !ok {
return nil, ctx.createErrorFromStatus()
}
return ctx, nil
}
const (
IndexTypeUnindexed = 1
IndexTypeIndexed = 2
)
// createExplicitEncryptionContext creates an explicit encryption context.
func (m *MongoCrypt) createExplicitEncryptionContext(opts *options.ExplicitEncryptionOptions) (*Context, error) {
ctx := newContext(C.mongocrypt_ctx_new(m.wrapped))
if ctx.wrapped == nil {
return nil, m.createErrorFromStatus()
}
if opts.KeyID != nil {
keyIDBinary := newBinaryFromBytes(opts.KeyID.Data)
defer keyIDBinary.close()
if ok := C.mongocrypt_ctx_setopt_key_id(ctx.wrapped, keyIDBinary.wrapped); !ok {
return nil, ctx.createErrorFromStatus()
}
}
if opts.KeyAltName != nil {
if err := setAltName(ctx, *opts.KeyAltName); err != nil {
return nil, err
}
}
if opts.RangeOptions != nil {
idx, mongocryptDoc := bsoncore.AppendDocumentStart(nil)
if opts.RangeOptions.Min != nil {
mongocryptDoc = bsoncore.AppendValueElement(mongocryptDoc, "min", *opts.RangeOptions.Min)
}
if opts.RangeOptions.Max != nil {
mongocryptDoc = bsoncore.AppendValueElement(mongocryptDoc, "max", *opts.RangeOptions.Max)
}
if opts.RangeOptions.Precision != nil {
mongocryptDoc = bsoncore.AppendInt32Element(mongocryptDoc, "precision", *opts.RangeOptions.Precision)
}
if opts.RangeOptions.Sparsity != nil {
mongocryptDoc = bsoncore.AppendInt64Element(mongocryptDoc, "sparsity", *opts.RangeOptions.Sparsity)
}
if opts.RangeOptions.TrimFactor != nil {
mongocryptDoc = bsoncore.AppendInt32Element(mongocryptDoc, "trimFactor", *opts.RangeOptions.TrimFactor)
}
mongocryptDoc, err := bsoncore.AppendDocumentEnd(mongocryptDoc, idx)
if err != nil {
return nil, err
}
mongocryptBinary := newBinaryFromBytes(mongocryptDoc)
defer mongocryptBinary.close()
if ok := C.mongocrypt_ctx_setopt_algorithm_range(ctx.wrapped, mongocryptBinary.wrapped); !ok {
return nil, ctx.createErrorFromStatus()
}
}
if opts.TextOptions != nil {
idx, mongocryptDoc := bsoncore.AppendDocumentStart(nil)
if opts.TextOptions.Substring != nil {
substringIdx, substringDoc := bsoncore.AppendDocumentStart(nil)
substringDoc = bsoncore.AppendInt32Element(substringDoc, "strMaxLength", opts.TextOptions.Substring.StrMaxLength)
substringDoc = bsoncore.AppendInt32Element(substringDoc, "strMinQueryLength", opts.TextOptions.Substring.StrMinQueryLength)
substringDoc = bsoncore.AppendInt32Element(substringDoc, "strMaxQueryLength", opts.TextOptions.Substring.StrMaxQueryLength)
substringDoc, err := bsoncore.AppendDocumentEnd(substringDoc, substringIdx)
if err != nil {
return nil, fmt.Errorf("error building substring doc: %w", err)
}
mongocryptDoc = bsoncore.AppendDocumentElement(mongocryptDoc, "substring", substringDoc)
}
if opts.TextOptions.Prefix != nil {
prefixIdx, prefixDoc := bsoncore.AppendDocumentStart(nil)
prefixDoc = bsoncore.AppendInt32Element(prefixDoc, "strMinQueryLength", opts.TextOptions.Prefix.StrMinQueryLength)
prefixDoc = bsoncore.AppendInt32Element(prefixDoc, "strMaxQueryLength", opts.TextOptions.Prefix.StrMaxQueryLength)
prefixDoc, err := bsoncore.AppendDocumentEnd(prefixDoc, prefixIdx)
if err != nil {
return nil, fmt.Errorf("error building prefix doc: %w", err)
}
mongocryptDoc = bsoncore.AppendDocumentElement(mongocryptDoc, "prefix", prefixDoc)
}
if opts.TextOptions.Suffix != nil {
suffixIdx, suffixDoc := bsoncore.AppendDocumentStart(nil)
suffixDoc = bsoncore.AppendInt32Element(suffixDoc, "strMinQueryLength", opts.TextOptions.Suffix.StrMinQueryLength)
suffixDoc = bsoncore.AppendInt32Element(suffixDoc, "strMaxQueryLength", opts.TextOptions.Suffix.StrMaxQueryLength)
suffixDoc, err := bsoncore.AppendDocumentEnd(suffixDoc, suffixIdx)
if err != nil {
return nil, fmt.Errorf("error building suffix doc: %w", err)
}
mongocryptDoc = bsoncore.AppendDocumentElement(mongocryptDoc, "suffix", suffixDoc)
}
mongocryptDoc = bsoncore.AppendBooleanElement(mongocryptDoc, "caseSensitive", opts.TextOptions.CaseSensitive)
mongocryptDoc = bsoncore.AppendBooleanElement(mongocryptDoc, "diacriticSensitive", opts.TextOptions.DiacriticSensitive)
mongocryptDoc, err := bsoncore.AppendDocumentEnd(mongocryptDoc, idx)
if err != nil {
return nil, fmt.Errorf("error building text options doc: %w", err)
}
mongocryptBinary := newBinaryFromBytes(mongocryptDoc)
defer mongocryptBinary.close()
if ok := C.mongocrypt_ctx_setopt_algorithm_text(ctx.wrapped, mongocryptBinary.wrapped); !ok {
return nil, fmt.Errorf("error setting text algorithm option: %w", ctx.createErrorFromStatus())
}
}
algoStr := C.CString(opts.Algorithm)
defer C.free(unsafe.Pointer(algoStr))
if ok := C.mongocrypt_ctx_setopt_algorithm(ctx.wrapped, algoStr, -1); !ok {
return nil, ctx.createErrorFromStatus()
}
if opts.QueryType != "" {
queryStr := C.CString(opts.QueryType)
defer C.free(unsafe.Pointer(queryStr))
if ok := C.mongocrypt_ctx_setopt_query_type(ctx.wrapped, queryStr, -1); !ok {
return nil, ctx.createErrorFromStatus()
}
}
if opts.ContentionFactor != nil {
if ok := C.mongocrypt_ctx_setopt_contention_factor(ctx.wrapped, C.int64_t(*opts.ContentionFactor)); !ok {
return nil, ctx.createErrorFromStatus()
}
}
return ctx, nil
}
// CreateExplicitEncryptionContext creates a Context to use for explicit encryption.
func (m *MongoCrypt) CreateExplicitEncryptionContext(doc bsoncore.Document, opts *options.ExplicitEncryptionOptions) (*Context, error) {
ctx, err := m.createExplicitEncryptionContext(opts)
if err != nil {
return ctx, err
}
docBinary := newBinaryFromBytes(doc)
defer docBinary.close()
if ok := C.mongocrypt_ctx_explicit_encrypt_init(ctx.wrapped, docBinary.wrapped); !ok {
return nil, ctx.createErrorFromStatus()
}
return ctx, nil
}
// CreateExplicitEncryptionExpressionContext creates a Context to use for explicit encryption of an expression.
func (m *MongoCrypt) CreateExplicitEncryptionExpressionContext(doc bsoncore.Document, opts *options.ExplicitEncryptionOptions) (*Context, error) {
ctx, err := m.createExplicitEncryptionContext(opts)
if err != nil {
return ctx, err
}
docBinary := newBinaryFromBytes(doc)
defer docBinary.close()
if ok := C.mongocrypt_ctx_explicit_encrypt_expression_init(ctx.wrapped, docBinary.wrapped); !ok {
return nil, ctx.createErrorFromStatus()
}
return ctx, nil
}
// CreateExplicitDecryptionContext creates a Context to use for explicit decryption.
func (m *MongoCrypt) CreateExplicitDecryptionContext(doc bsoncore.Document) (*Context, error) {
ctx := newContext(C.mongocrypt_ctx_new(m.wrapped))
if ctx.wrapped == nil {
return nil, m.createErrorFromStatus()
}
docBinary := newBinaryFromBytes(doc)
defer docBinary.close()
if ok := C.mongocrypt_ctx_explicit_decrypt_init(ctx.wrapped, docBinary.wrapped); !ok {
return nil, ctx.createErrorFromStatus()
}
return ctx, nil
}
// CryptSharedLibVersion returns the version number for the loaded crypt_shared library, or 0 if the
// crypt_shared library was not loaded.
func (m *MongoCrypt) CryptSharedLibVersion() uint64 {
return uint64(C.mongocrypt_crypt_shared_lib_version(m.wrapped))
}
// CryptSharedLibVersionString returns the version string for the loaded crypt_shared library, or an
// empty string if the crypt_shared library was not loaded.
func (m *MongoCrypt) CryptSharedLibVersionString() string {
// Pass in a pointer for "len", but ignore the value because C.GoString can determine the string
// length without it.
len := C.uint(0)
str := C.GoString(C.mongocrypt_crypt_shared_lib_version_string(m.wrapped, &len))
return str
}
// Close cleans up any resources associated with the given MongoCrypt instance.
func (m *MongoCrypt) Close() {
C.mongocrypt_destroy(m.wrapped)
if m.httpClient == httputil.DefaultHTTPClient {
httputil.CloseIdleHTTPConnections(m.httpClient)
}
}
// RewrapDataKeyContext create a Context to use for rewrapping a data key.
func (m *MongoCrypt) RewrapDataKeyContext(filter []byte, opts *options.RewrapManyDataKeyOptions) (*Context, error) {
const masterKey = "masterKey"
const providerKey = "provider"
ctx := newContext(C.mongocrypt_ctx_new(m.wrapped))
if ctx.wrapped == nil {
return nil, m.createErrorFromStatus()
}
if opts.MasterKey != nil && opts.Provider == nil {
// Provider is nil, but MasterKey is set. This is an error.
return nil, fmt.Errorf("expected 'Provider' to be set to identify type of 'MasterKey'")
}
if opts.Provider != nil {
// If a provider has been specified, create an encryption key document for creating a data key or for rewrapping
// datakeys. If a new provider is not specified, then the filter portion of this logic returns the data as it
// exists in the collection.
idx, mongocryptDoc := bsoncore.AppendDocumentStart(nil)
mongocryptDoc = bsoncore.AppendStringElement(mongocryptDoc, providerKey, *opts.Provider)
if opts.MasterKey != nil {
mongocryptDoc = opts.MasterKey[:len(opts.MasterKey)-1]
mongocryptDoc = bsoncore.AppendStringElement(mongocryptDoc, providerKey, *opts.Provider)
}
mongocryptDoc, err := bsoncore.AppendDocumentEnd(mongocryptDoc, idx)
if err != nil {
return nil, err
}
mongocryptBinary := newBinaryFromBytes(mongocryptDoc)
defer mongocryptBinary.close()
// Add new masterKey to the mongocrypt context.
if ok := C.mongocrypt_ctx_setopt_key_encryption_key(ctx.wrapped, mongocryptBinary.wrapped); !ok {
return nil, ctx.createErrorFromStatus()
}
}
return ctx, rewrapDataKey(ctx, filter)
}
func (m *MongoCrypt) setProviderOptions(kmsProviders bsoncore.Document) error {
providersBinary := newBinaryFromBytes(kmsProviders)
defer providersBinary.close()
if ok := C.mongocrypt_setopt_kms_providers(m.wrapped, providersBinary.wrapped); !ok {
return m.createErrorFromStatus()
}
return nil
}
// setLocalSchemaMap sets the local schema map in mongocrypt.
func (m *MongoCrypt) setLocalSchemaMap(schemaMap map[string]bsoncore.Document) error {
if len(schemaMap) == 0 {
return nil
}
// convert schema map to BSON document
schemaMapBSON, err := bson.Marshal(schemaMap)
if err != nil {
return fmt.Errorf("error marshalling SchemaMap: %v", err)
}
schemaMapBinary := newBinaryFromBytes(schemaMapBSON)
defer schemaMapBinary.close()
if ok := C.mongocrypt_setopt_schema_map(m.wrapped, schemaMapBinary.wrapped); !ok {
return m.createErrorFromStatus()
}
return nil
}
// setEncryptedFieldsMap sets the encryptedfields map in mongocrypt.
func (m *MongoCrypt) setEncryptedFieldsMap(encryptedfieldsMap map[string]bsoncore.Document) error {
if len(encryptedfieldsMap) == 0 {
return nil
}
// convert encryptedfields map to BSON document
encryptedfieldsMapBSON, err := bson.Marshal(encryptedfieldsMap)
if err != nil {
return fmt.Errorf("error marshalling EncryptedFieldsMap: %v", err)
}
encryptedfieldsMapBinary := newBinaryFromBytes(encryptedfieldsMapBSON)
defer encryptedfieldsMapBinary.close()
if ok := C.mongocrypt_setopt_encrypted_field_config_map(m.wrapped, encryptedfieldsMapBinary.wrapped); !ok {
return m.createErrorFromStatus()
}
return nil
}
// createErrorFromStatus creates a new Error based on the status of the MongoCrypt instance.
func (m *MongoCrypt) createErrorFromStatus() error {
status := C.mongocrypt_status_new()
defer C.mongocrypt_status_destroy(status)
C.mongocrypt_status(m.wrapped, status)
return errorFromStatus(status)
}
// needsKmsProvider returns true if provider was initially set to an empty document.
// An empty document signals the driver to fetch credentials.
func needsKmsProvider(kmsProviders bsoncore.Document, provider string) bool {
val, err := kmsProviders.LookupErr(provider)
if err != nil {
// KMS provider is not configured.
return false
}
doc, ok := val.DocumentOK()
// KMS provider is an empty document if the length is 5.
// An empty document contains 4 bytes of "\x00" and a null byte.
return ok && len(doc) == 5
}
// GetKmsProviders attempts to obtain credentials from environment.
// It is expected to be called when a libmongocrypt context is in the mongocrypt.NeedKmsCredentials state.
func (m *MongoCrypt) GetKmsProviders(ctx context.Context) (bsoncore.Document, error) {
builder := bsoncore.NewDocumentBuilder()
for k, p := range m.kmsProviders {
doc, err := p.GetCredentialsDoc(ctx)
if err != nil {
return nil, fmt.Errorf("unable to retrieve %s credentials: %w", k, err)
}
builder.AppendDocument(k, doc)
}
return builder.Build(), nil
}

View File

@@ -0,0 +1,115 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build cse
package mongocrypt
// #include <mongocrypt.h>
import "C"
import (
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// Context represents a mongocrypt_ctx_t handle
type Context struct {
wrapped *C.mongocrypt_ctx_t
}
// newContext creates a Context wrapper around the given C type.
func newContext(wrapped *C.mongocrypt_ctx_t) *Context {
return &Context{
wrapped: wrapped,
}
}
// State returns the current State of the Context.
func (c *Context) State() State {
return State(int(C.mongocrypt_ctx_state(c.wrapped)))
}
// NextOperation gets the document for the next database operation to run.
func (c *Context) NextOperation() (bsoncore.Document, error) {
opDocBinary := newBinary() // out param for mongocrypt_ctx_mongo_op to fill in operation
defer opDocBinary.close()
if ok := C.mongocrypt_ctx_mongo_op(c.wrapped, opDocBinary.wrapped); !ok {
return nil, c.createErrorFromStatus()
}
return opDocBinary.toBytes(), nil
}
// AddOperationResult feeds the result of a database operation to mongocrypt.
func (c *Context) AddOperationResult(result bsoncore.Document) error {
resultBinary := newBinaryFromBytes(result)
defer resultBinary.close()
if ok := C.mongocrypt_ctx_mongo_feed(c.wrapped, resultBinary.wrapped); !ok {
return c.createErrorFromStatus()
}
return nil
}
// CompleteOperation signals a database operation has been completed.
func (c *Context) CompleteOperation() error {
if ok := C.mongocrypt_ctx_mongo_done(c.wrapped); !ok {
return c.createErrorFromStatus()
}
return nil
}
// NextKmsContext returns the next KmsContext, or nil if there are no more.
func (c *Context) NextKmsContext() *KmsContext {
ctx := C.mongocrypt_ctx_next_kms_ctx(c.wrapped)
if ctx == nil {
return nil
}
return newKmsContext(ctx)
}
// FinishKmsContexts signals that all KMS contexts have been completed.
func (c *Context) FinishKmsContexts() error {
if ok := C.mongocrypt_ctx_kms_done(c.wrapped); !ok {
return c.createErrorFromStatus()
}
return nil
}
// Finish performs the final operations for the context and returns the resulting document.
func (c *Context) Finish() (bsoncore.Document, error) {
docBinary := newBinary() // out param for mongocrypt_ctx_finalize to fill in resulting document
defer docBinary.close()
if ok := C.mongocrypt_ctx_finalize(c.wrapped, docBinary.wrapped); !ok {
return nil, c.createErrorFromStatus()
}
return docBinary.toBytes(), nil
}
// Close cleans up any resources associated with the given Context instance.
func (c *Context) Close() {
C.mongocrypt_ctx_destroy(c.wrapped)
}
// createErrorFromStatus creates a new Error based on the status of the MongoCrypt instance.
func (c *Context) createErrorFromStatus() error {
status := C.mongocrypt_status_new()
defer C.mongocrypt_status_destroy(status)
C.mongocrypt_ctx_status(c.wrapped, status)
return errorFromStatus(status)
}
// ProvideKmsProviders provides the KMS providers when in the NeedKmsCredentials state.
func (c *Context) ProvideKmsProviders(kmsProviders bsoncore.Document) error {
kmsProvidersBinary := newBinaryFromBytes(kmsProviders)
defer kmsProvidersBinary.close()
if ok := C.mongocrypt_ctx_provide_kms_providers(c.wrapped, kmsProvidersBinary.wrapped); !ok {
return c.createErrorFromStatus()
}
return nil
}

View File

@@ -0,0 +1,61 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build !cse
package mongocrypt
import (
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// Context represents a mongocrypt_ctx_t handle
type Context struct{}
// State returns the current State of the Context.
func (c *Context) State() State {
panic(cseNotSupportedMsg)
}
// NextOperation gets the document for the next database operation to run.
func (c *Context) NextOperation() (bsoncore.Document, error) {
panic(cseNotSupportedMsg)
}
// AddOperationResult feeds the result of a database operation to mongocrypt.
func (c *Context) AddOperationResult(bsoncore.Document) error {
panic(cseNotSupportedMsg)
}
// CompleteOperation signals a database operation has been completed.
func (c *Context) CompleteOperation() error {
panic(cseNotSupportedMsg)
}
// NextKmsContext returns the next KmsContext, or nil if there are no more.
func (c *Context) NextKmsContext() *KmsContext {
panic(cseNotSupportedMsg)
}
// FinishKmsContexts signals that all KMS contexts have been completed.
func (c *Context) FinishKmsContexts() error {
panic(cseNotSupportedMsg)
}
// Finish performs the final operations for the context and returns the resulting document.
func (c *Context) Finish() (bsoncore.Document, error) {
panic(cseNotSupportedMsg)
}
// Close cleans up any resources associated with the given Context instance.
func (c *Context) Close() {
panic(cseNotSupportedMsg)
}
// ProvideKmsProviders provides the KMS providers when in the NeedKmsCredentials state.
func (c *Context) ProvideKmsProviders(bsoncore.Document) error {
panic(cseNotSupportedMsg)
}

View File

@@ -0,0 +1,86 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build cse
package mongocrypt
// #include <mongocrypt.h>
import "C"
import "time"
// KmsContext represents a mongocrypt_kms_ctx_t handle.
type KmsContext struct {
wrapped *C.mongocrypt_kms_ctx_t
}
// newKmsContext creates a KmsContext wrapper around the given C type.
func newKmsContext(wrapped *C.mongocrypt_kms_ctx_t) *KmsContext {
return &KmsContext{
wrapped: wrapped,
}
}
// HostName gets the host name of the KMS.
func (kc *KmsContext) HostName() (string, error) {
var hostname *C.char // out param for mongocrypt function to fill in hostname
if ok := C.mongocrypt_kms_ctx_endpoint(kc.wrapped, &hostname); !ok {
return "", kc.createErrorFromStatus()
}
return C.GoString(hostname), nil
}
// KMSProvider gets the KMS provider of the KMS context.
func (kc *KmsContext) KMSProvider() string {
kmsProvider := C.mongocrypt_kms_ctx_get_kms_provider(kc.wrapped, nil)
return C.GoString(kmsProvider)
}
// Message returns the message to send to the KMS.
func (kc *KmsContext) Message() ([]byte, error) {
time.Sleep(time.Duration(C.mongocrypt_kms_ctx_usleep(kc.wrapped)) * time.Microsecond)
msgBinary := newBinary()
defer msgBinary.close()
if ok := C.mongocrypt_kms_ctx_message(kc.wrapped, msgBinary.wrapped); !ok {
return nil, kc.createErrorFromStatus()
}
return msgBinary.toBytes(), nil
}
// BytesNeeded returns the number of bytes that should be received from the KMS.
// After sending the message to the KMS, this message should be called in a loop until the number returned is 0.
func (kc *KmsContext) BytesNeeded() int32 {
return int32(C.mongocrypt_kms_ctx_bytes_needed(kc.wrapped))
}
// FeedResponse feeds the bytes received from the KMS to mongocrypt.
func (kc *KmsContext) FeedResponse(response []byte) error {
responseBinary := newBinaryFromBytes(response)
defer responseBinary.close()
if ok := C.mongocrypt_kms_ctx_feed(kc.wrapped, responseBinary.wrapped); !ok {
return kc.createErrorFromStatus()
}
return nil
}
// createErrorFromStatus creates a new Error from the status of the KmsContext instance.
func (kc *KmsContext) createErrorFromStatus() error {
status := C.mongocrypt_status_new()
defer C.mongocrypt_status_destroy(status)
C.mongocrypt_kms_ctx_status(kc.wrapped, status)
return errorFromStatus(status)
}
// RequestError returns the source of the network error for KMS requests.
func (kc *KmsContext) RequestError() error {
if bool(C.mongocrypt_kms_ctx_fail(kc.wrapped)) {
return nil
}
return kc.createErrorFromStatus()
}

View File

@@ -0,0 +1,43 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build !cse
package mongocrypt
// KmsContext represents a mongocrypt_kms_ctx_t handle.
type KmsContext struct{}
// HostName gets the host name of the KMS.
func (kc *KmsContext) HostName() (string, error) {
panic(cseNotSupportedMsg)
}
// Message returns the message to send to the KMS.
func (kc *KmsContext) Message() ([]byte, error) {
panic(cseNotSupportedMsg)
}
// KMSProvider gets the KMS provider of the KMS context.
func (kc *KmsContext) KMSProvider() string {
panic(cseNotSupportedMsg)
}
// BytesNeeded returns the number of bytes that should be received from the KMS.
// After sending the message to the KMS, this message should be called in a loop until the number returned is 0.
func (kc *KmsContext) BytesNeeded() int32 {
panic(cseNotSupportedMsg)
}
// FeedResponse feeds the bytes received from the KMS to mongocrypt.
func (kc *KmsContext) FeedResponse([]byte) error {
panic(cseNotSupportedMsg)
}
// RequestError returns the source of the network error for KMS requests.
func (kc *KmsContext) RequestError() error {
panic(cseNotSupportedMsg)
}

View File

@@ -0,0 +1,96 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build !cse
// Package mongocrypt is intended for internal use only. It is made available to
// facilitate use cases that require access to internal MongoDB driver
// functionality and state. The API of this package is not stable and there is
// no backward compatibility guarantee.
//
// WARNING: THIS PACKAGE IS EXPERIMENTAL AND MAY BE MODIFIED OR REMOVED WITHOUT
// NOTICE! USE WITH EXTREME CAUTION!
package mongocrypt
import (
"context"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt/options"
)
const cseNotSupportedMsg = "client-side encryption not enabled. add the cse build tag to support"
// MongoCrypt represents a mongocrypt_t handle.
type MongoCrypt struct{}
// Version returns the version string for the loaded libmongocrypt, or an empty string
// if libmongocrypt was not loaded.
func Version() string {
return ""
}
// NewMongoCrypt constructs a new MongoCrypt instance configured using the provided MongoCryptOptions.
func NewMongoCrypt(*options.MongoCryptOptions) (*MongoCrypt, error) {
panic(cseNotSupportedMsg)
}
// CreateEncryptionContext creates a Context to use for encryption.
func (m *MongoCrypt) CreateEncryptionContext(string, bsoncore.Document) (*Context, error) {
panic(cseNotSupportedMsg)
}
// CreateExplicitEncryptionExpressionContext creates a Context to use for explicit encryption of an expression.
func (m *MongoCrypt) CreateExplicitEncryptionExpressionContext(bsoncore.Document, *options.ExplicitEncryptionOptions) (*Context, error) {
panic(cseNotSupportedMsg)
}
// CreateDecryptionContext creates a Context to use for decryption.
func (m *MongoCrypt) CreateDecryptionContext(bsoncore.Document) (*Context, error) {
panic(cseNotSupportedMsg)
}
// CreateDataKeyContext creates a Context to use for creating a data key.
func (m *MongoCrypt) CreateDataKeyContext(string, *options.DataKeyOptions) (*Context, error) {
panic(cseNotSupportedMsg)
}
// CreateExplicitEncryptionContext creates a Context to use for explicit encryption.
func (m *MongoCrypt) CreateExplicitEncryptionContext(bsoncore.Document, *options.ExplicitEncryptionOptions) (*Context, error) {
panic(cseNotSupportedMsg)
}
// RewrapDataKeyContext creates a Context to use for rewrapping a data key.
func (m *MongoCrypt) RewrapDataKeyContext([]byte, *options.RewrapManyDataKeyOptions) (*Context, error) {
panic(cseNotSupportedMsg)
}
// CreateExplicitDecryptionContext creates a Context to use for explicit decryption.
func (m *MongoCrypt) CreateExplicitDecryptionContext(bsoncore.Document) (*Context, error) {
panic(cseNotSupportedMsg)
}
// CryptSharedLibVersion returns the version number for the loaded crypt_shared library, or 0 if the
// crypt_shared library was not loaded.
func (m *MongoCrypt) CryptSharedLibVersion() uint64 {
panic(cseNotSupportedMsg)
}
// CryptSharedLibVersionString returns the version string for the loaded crypt_shared library, or an
// empty string if the crypt_shared library was not loaded.
func (m *MongoCrypt) CryptSharedLibVersionString() string {
panic(cseNotSupportedMsg)
}
// Close cleans up any resources associated with the given MongoCrypt instance.
func (m *MongoCrypt) Close() {
panic(cseNotSupportedMsg)
}
// GetKmsProviders returns the originally configured KMS providers.
func (m *MongoCrypt) GetKmsProviders(context.Context) (bsoncore.Document, error) {
panic(cseNotSupportedMsg)
}

View File

@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package options is intended for internal use only. It is made available to
// facilitate use cases that require access to internal MongoDB driver
// functionality and state. The API of this package is not stable and there is
// no backward compatibility guarantee.
//
// WARNING: THIS PACKAGE IS EXPERIMENTAL AND MAY BE MODIFIED OR REMOVED WITHOUT
// NOTICE! USE WITH EXTREME CAUTION!
package options

View File

@@ -0,0 +1,132 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package options
import (
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// DataKeyOptions specifies options for creating a new data key.
type DataKeyOptions struct {
KeyAltNames []string
KeyMaterial []byte
MasterKey bsoncore.Document
}
// QueryType describes the type of query the result of Encrypt is used for.
type QueryType int
// These constants specify valid values for QueryType
const (
QueryTypeEquality QueryType = 1
)
// ExplicitEncryptionOptions specifies options for configuring an explicit encryption context.
type ExplicitEncryptionOptions struct {
KeyID *bson.Binary
KeyAltName *string
Algorithm string
QueryType string
ContentionFactor *int64
RangeOptions *ExplicitRangeOptions
TextOptions *ExplicitTextOptions
}
// ExplicitRangeOptions specifies options for the range index.
type ExplicitRangeOptions struct {
Min *bsoncore.Value
Max *bsoncore.Value
Sparsity *int64
TrimFactor *int32
Precision *int32
}
// ExplicitTextOptions specifies options for the text query.
type ExplicitTextOptions struct {
Substring *SubstringOptions
Prefix *PrefixOptions
Suffix *SuffixOptions
CaseSensitive bool
DiacriticSensitive bool
}
// SubstringOptions specifies options to support substring queries.
type SubstringOptions struct {
StrMaxLength int32
StrMinQueryLength int32
StrMaxQueryLength int32
}
// PrefixOptions specifies options to support prefix queries.
type PrefixOptions struct {
StrMinQueryLength int32
StrMaxQueryLength int32
}
// SuffixOptions specifies options to support suffix queries.
type SuffixOptions struct {
StrMinQueryLength int32
StrMaxQueryLength int32
}
// ExplicitEncryption creates a new ExplicitEncryptionOptions instance.
func ExplicitEncryption() *ExplicitEncryptionOptions {
return &ExplicitEncryptionOptions{}
}
// SetKeyID sets the key identifier.
func (eeo *ExplicitEncryptionOptions) SetKeyID(keyID bson.Binary) *ExplicitEncryptionOptions {
eeo.KeyID = &keyID
return eeo
}
// SetKeyAltName sets the key alternative name.
func (eeo *ExplicitEncryptionOptions) SetKeyAltName(keyAltName string) *ExplicitEncryptionOptions {
eeo.KeyAltName = &keyAltName
return eeo
}
// SetAlgorithm specifies an encryption algorithm.
func (eeo *ExplicitEncryptionOptions) SetAlgorithm(algorithm string) *ExplicitEncryptionOptions {
eeo.Algorithm = algorithm
return eeo
}
// SetQueryType specifies the query type.
func (eeo *ExplicitEncryptionOptions) SetQueryType(queryType string) *ExplicitEncryptionOptions {
eeo.QueryType = queryType
return eeo
}
// SetContentionFactor specifies the contention factor.
func (eeo *ExplicitEncryptionOptions) SetContentionFactor(contentionFactor int64) *ExplicitEncryptionOptions {
eeo.ContentionFactor = &contentionFactor
return eeo
}
// SetRangeOptions specifies the range options.
func (eeo *ExplicitEncryptionOptions) SetRangeOptions(ro ExplicitRangeOptions) *ExplicitEncryptionOptions {
eeo.RangeOptions = &ro
return eeo
}
// SetTextOptions specifies the text options.
func (eeo *ExplicitEncryptionOptions) SetTextOptions(to ExplicitTextOptions) *ExplicitEncryptionOptions {
eeo.TextOptions = &to
return eeo
}
// RewrapManyDataKeyOptions represents all possible options used to decrypt and encrypt all matching data keys with a
// possibly new masterKey.
type RewrapManyDataKeyOptions struct {
// Provider identifies the new KMS provider. If omitted, encrypting uses the current KMS provider.
Provider *string
// MasterKey identifies the new masterKey. If omitted, rewraps with the current masterKey.
MasterKey bsoncore.Document
}

View File

@@ -0,0 +1,26 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package options
import (
"net/http"
"time"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// MongoCryptOptions specifies options to configure a MongoCrypt instance.
type MongoCryptOptions struct {
KmsProviders bsoncore.Document
LocalSchemaMap map[string]bsoncore.Document
BypassQueryAnalysis bool
EncryptedFieldsMap map[string]bsoncore.Document
CryptSharedLibDisabled bool
CryptSharedLibOverridePath string
HTTPClient *http.Client
KeyExpiration *time.Duration
}

View File

@@ -0,0 +1,47 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package mongocrypt
// State represents a state that a MongocryptContext can be in.
type State int
// These constants are valid values for the State type.
// The values must match the values defined in the mongocrypt_ctx_state_t enum in libmongocrypt.
const (
StateError State = 0
NeedMongoCollInfo State = 1
NeedMongoMarkings State = 2
NeedMongoKeys State = 3
NeedKms State = 4
Ready State = 5
Done State = 6
NeedKmsCredentials State = 7
)
// String implements the Stringer interface.
func (s State) String() string {
switch s {
case StateError:
return "Error"
case NeedMongoCollInfo:
return "NeedMongoCollInfo"
case NeedMongoMarkings:
return "NeedMongoMarkings"
case NeedMongoKeys:
return "NeedMongoKeys"
case NeedKms:
return "NeedKms"
case Ready:
return "Ready"
case Done:
return "Done"
case NeedKmsCredentials:
return "NeedKmsCredentials"
default:
return "Unknown State"
}
}

View File

@@ -0,0 +1,121 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package ocsp
import (
"crypto"
"sync"
"time"
"golang.org/x/crypto/ocsp"
)
type cacheKey struct {
HashAlgorithm crypto.Hash
IssuerNameHash string
IssuerKeyHash string
SerialNumber string
}
// Cache represents an OCSP cache.
type Cache interface {
Update(*ocsp.Request, *ResponseDetails) *ResponseDetails
Get(request *ocsp.Request) *ResponseDetails
}
// ConcurrentCache is an implementation of ocsp.Cache that's safe for concurrent use.
type ConcurrentCache struct {
cache map[cacheKey]*ResponseDetails
sync.Mutex
}
var _ Cache = (*ConcurrentCache)(nil)
// NewCache creates an empty OCSP cache.
func NewCache() *ConcurrentCache {
return &ConcurrentCache{
cache: make(map[cacheKey]*ResponseDetails),
}
}
// Update updates the cache entry for the provided request. The provided response will only be cached if it has a
// status that is not ocsp.Unknown and has a non-zero NextUpdate time. If there is an existing cache entry for request,
// it will be overwritten by response if response.NextUpdate is further ahead in the future than the existing entry's
// NextUpdate.
//
// This function returns the most up-to-date response corresponding to the request.
func (c *ConcurrentCache) Update(request *ocsp.Request, response *ResponseDetails) *ResponseDetails {
unknown := response.Status == ocsp.Unknown
hasUpdateTime := !response.NextUpdate.IsZero()
canBeCached := !unknown && hasUpdateTime
key := createCacheKey(request)
c.Lock()
defer c.Unlock()
current, ok := c.cache[key]
if !ok {
if canBeCached {
c.cache[key] = response
}
// Return the provided response even though it might not have been cached because it's the most up-to-date
// response available.
return response
}
// If the new response is Unknown, we can't cache it. Return the existing cached response.
if unknown {
return current
}
// If a response has no nextUpdate set, the responder is telling us that newer information is always available.
// In this case, remove the existing cache entry because it is stale and return the new response because it is
// more up-to-date.
if !hasUpdateTime {
delete(c.cache, key)
return response
}
// If we get here, the new response is conclusive and has a non-empty nextUpdate so it can be cached. Overwrite
// the existing cache entry if the new one will be valid for longer.
newest := current
if response.NextUpdate.After(current.NextUpdate) {
c.cache[key] = response
newest = response
}
return newest
}
// Get returns the cached response for the request, or nil if there is no cached response. If the cached response has
// expired, it will be removed from the cache and nil will be returned.
func (c *ConcurrentCache) Get(request *ocsp.Request) *ResponseDetails {
key := createCacheKey(request)
c.Lock()
defer c.Unlock()
response, ok := c.cache[key]
if !ok {
return nil
}
if time.Now().UTC().Before(response.NextUpdate) {
return response
}
delete(c.cache, key)
return nil
}
func createCacheKey(request *ocsp.Request) cacheKey {
return cacheKey{
HashAlgorithm: request.HashAlgorithm,
IssuerNameHash: string(request.IssuerNameHash),
IssuerKeyHash: string(request.IssuerKeyHash),
SerialNumber: request.SerialNumber.String(),
}
}

View File

@@ -0,0 +1,68 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package ocsp
import (
"crypto/x509"
"errors"
"fmt"
"net/http"
"go.mongodb.org/mongo-driver/v2/internal/httputil"
"golang.org/x/crypto/ocsp"
)
type config struct {
serverCert, issuer *x509.Certificate
cache Cache
disableEndpointChecking bool
ocspRequest *ocsp.Request
ocspRequestBytes []byte
httpClient *http.Client
}
func newConfig(certChain []*x509.Certificate, opts *VerifyOptions) (config, error) {
cfg := config{
cache: opts.Cache,
disableEndpointChecking: opts.DisableEndpointChecking,
httpClient: opts.HTTPClient,
}
if cfg.httpClient == nil {
cfg.httpClient = httputil.DefaultHTTPClient
}
if len(certChain) == 0 {
return cfg, errors.New("verified certificate chain contained no certificates")
}
// In the case where the leaf certificate and CA are the same, the chain may only contain one certificate.
cfg.serverCert = certChain[0]
cfg.issuer = certChain[0]
if len(certChain) > 1 {
// If the chain has multiple certificates, the one directly after the leaf should be the issuer. Use
// CheckSignatureFrom to verify that it is the issuer.
cfg.issuer = certChain[1]
if err := cfg.serverCert.CheckSignatureFrom(cfg.issuer); err != nil {
errString := "error checking if server certificate is signed by the issuer in the verified chain: %v"
return cfg, fmt.Errorf(errString, err)
}
}
var err error
cfg.ocspRequestBytes, err = ocsp.CreateRequest(cfg.serverCert, cfg.issuer, nil)
if err != nil {
return cfg, fmt.Errorf("error creating OCSP request: %w", err)
}
cfg.ocspRequest, err = ocsp.ParseRequest(cfg.ocspRequestBytes)
if err != nil {
return cfg, fmt.Errorf("error parsing OCSP request bytes: %w", err)
}
return cfg, nil
}

View File

@@ -0,0 +1,328 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package ocsp is intended for internal use only. It is made available to
// facilitate use cases that require access to internal MongoDB driver
// functionality and state. The API of this package is not stable and there is
// no backward compatibility guarantee.
//
// WARNING: THIS PACKAGE IS EXPERIMENTAL AND MAY BE MODIFIED OR REMOVED WITHOUT
// NOTICE! USE WITH EXTREME CAUTION!
package ocsp
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"time"
"golang.org/x/crypto/ocsp"
"golang.org/x/sync/errgroup"
)
var (
tlsFeatureExtensionOID = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 24}
mustStapleFeatureValue = big.NewInt(5)
)
// Error represents an OCSP verification error
type Error struct {
wrapped error
}
// Error implements the error interface
func (e *Error) Error() string {
return fmt.Sprintf("OCSP verification failed: %v", e.wrapped)
}
// Unwrap returns the underlying error.
func (e *Error) Unwrap() error {
return e.wrapped
}
func newOCSPError(wrapped error) error {
return &Error{wrapped: wrapped}
}
// ResponseDetails contains a subset of the details needed from an OCSP response after the original response has been
// validated.
type ResponseDetails struct {
Status int
NextUpdate time.Time
}
func extractResponseDetails(res *ocsp.Response) *ResponseDetails {
return &ResponseDetails{
Status: res.Status,
NextUpdate: res.NextUpdate,
}
}
// Verify performs OCSP verification for the provided ConnectionState instance.
func Verify(ctx context.Context, connState tls.ConnectionState, opts *VerifyOptions) error {
if opts.Cache == nil {
// There should always be an OCSP cache. Even if the user has specified the URI option to disable communication
// with OCSP responders, the driver will cache any stapled responses. Requiring that the cache is non-nil
// allows us to confirm that the cache is correctly being passed down from a higher level.
return newOCSPError(errors.New("no OCSP cache provided"))
}
if len(connState.VerifiedChains) == 0 {
return newOCSPError(errors.New("no verified certificate chains reported after TLS handshake"))
}
certChain := connState.VerifiedChains[0]
if numCerts := len(certChain); numCerts == 0 {
return newOCSPError(errors.New("verified chain contained no certificates"))
}
ocspCfg, err := newConfig(certChain, opts)
if err != nil {
return newOCSPError(err)
}
res, err := getParsedResponse(ctx, ocspCfg, connState)
if err != nil {
return err
}
if res == nil {
// If no response was parsed from the staple and responders, the status of the certificate is unknown, so don't
// error.
return nil
}
if res.Status == ocsp.Revoked {
return newOCSPError(errors.New("certificate is revoked"))
}
return nil
}
// getParsedResponse attempts to parse a response from the stapled OCSP data or by contacting OCSP responders if no
// staple is present.
func getParsedResponse(ctx context.Context, cfg config, connState tls.ConnectionState) (*ResponseDetails, error) {
stapledResponse, err := processStaple(cfg, connState.OCSPResponse)
if err != nil {
return nil, err
}
if stapledResponse != nil {
// If there is a staple, attempt to cache it. The cache.Update call will resolve conflicts with an existing
// cache enry if necessary.
return cfg.cache.Update(cfg.ocspRequest, stapledResponse), nil
}
if cachedResponse := cfg.cache.Get(cfg.ocspRequest); cachedResponse != nil {
return cachedResponse, nil
}
// If there is no stapled or cached response, fall back to querying the responders if that functionality has not
// been disabled.
if cfg.disableEndpointChecking {
return nil, nil
}
externalResponse := contactResponders(ctx, cfg)
if externalResponse == nil {
// None of the responders were available.
return nil, nil
}
// Similar to the stapled response case above, unconditionally call Update and it will either cache the response
// or resolve conflicts if a different connection has cached a response since the previous call to Get.
return cfg.cache.Update(cfg.ocspRequest, externalResponse), nil
}
// processStaple returns the OCSP response from the provided staple. An error will be returned if any of the following
// are true:
//
// 1. cfg.serverCert has the Must-Staple extension but the staple is empty.
// 2. The staple is malformed.
// 3. The staple does not cover cfg.serverCert.
// 4. The OCSP response has an error status.
func processStaple(cfg config, staple []byte) (*ResponseDetails, error) {
mustStaple, err := isMustStapleCertificate(cfg.serverCert)
if err != nil {
return nil, err
}
// If the server has a Must-Staple certificate and the server does not present a stapled OCSP response, error.
if mustStaple && len(staple) == 0 {
return nil, errors.New("server provided a certificate with the Must-Staple extension but did not " +
"provide a stapled OCSP response")
}
if len(staple) == 0 {
return nil, nil
}
parsedResponse, err := ocsp.ParseResponseForCert(staple, cfg.serverCert, cfg.issuer)
if err != nil {
// If the stapled response could not be parsed correctly, error. This can happen if the response is malformed,
// the response does not cover the certificate presented by the server, or if the response contains an error
// status.
return nil, fmt.Errorf("error parsing stapled response: %w", err)
}
if err = verifyResponse(cfg, parsedResponse); err != nil {
return nil, fmt.Errorf("error validating stapled response: %w", err)
}
return extractResponseDetails(parsedResponse), nil
}
// isMustStapleCertificate determines whether or not an X509 certificate is a must-staple certificate.
func isMustStapleCertificate(cert *x509.Certificate) (bool, error) {
var featureExtension pkix.Extension
var foundExtension bool
for _, ext := range cert.Extensions {
if ext.Id.Equal(tlsFeatureExtensionOID) {
featureExtension = ext
foundExtension = true
break
}
}
if !foundExtension {
return false, nil
}
// The value for the TLS feature extension is a sequence of integers. Per the asn1.Unmarshal documentation, an
// integer can be unmarshalled into an int, int32, int64, or *big.Int and unmarshalling will error if the integer
// cannot be encoded into the target type.
//
// Use []*big.Int to ensure that all values in the sequence can be successfully unmarshalled.
var featureValues []*big.Int
if _, err := asn1.Unmarshal(featureExtension.Value, &featureValues); err != nil {
return false, fmt.Errorf("error unmarshalling TLS feature extension values: %w", err)
}
for _, value := range featureValues {
if value.Cmp(mustStapleFeatureValue) == 0 {
return true, nil
}
}
return false, nil
}
// contactResponders will send a request to all OCSP responders reported by cfg.serverCert. The
// first response that conclusively identifies cfg.serverCert as good or revoked will be returned.
// If all responders are unavailable or no responder returns a conclusive status, it returns nil.
// contactResponders will wait for up to 5 seconds to get a certificate status response.
func contactResponders(ctx context.Context, cfg config) *ResponseDetails {
if len(cfg.serverCert.OCSPServer) == 0 {
return nil
}
// Limit all OCSP responder calls to a maximum of 5 seconds or when the passed-in context expires,
// whichever happens first.
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
group, ctx := errgroup.WithContext(ctx)
ocspResponses := make(chan *ocsp.Response, len(cfg.serverCert.OCSPServer))
defer close(ocspResponses)
for _, endpoint := range cfg.serverCert.OCSPServer {
// Re-assign endpoint so it gets re-scoped rather than using the iteration variable in the goroutine. See
// https://golang.org/doc/faq#closures_and_goroutines.
endpoint := endpoint
// Start a group of goroutines that each attempt to request the certificate status from one
// of the OCSP endpoints listed in the server certificate. We want to "soft fail" on all
// errors, so this function never returns actual errors. Only a "done" error is returned
// when a response is received so the errgroup cancels any other in-progress requests.
group.Go(func() error {
// Use bytes.NewReader instead of bytes.NewBuffer because a bytes.Buffer is an owning representation and the
// docs recommend not using the underlying []byte after creating the buffer, so a new copy of the request
// bytes would be needed for each request.
request, err := http.NewRequest("POST", endpoint, bytes.NewReader(cfg.ocspRequestBytes))
if err != nil {
return nil
}
request = request.WithContext(ctx)
httpResponse, err := cfg.httpClient.Do(request)
if err != nil {
return nil
}
defer func() {
_ = httpResponse.Body.Close()
}()
if httpResponse.StatusCode != 200 {
return nil
}
httpBytes, err := ioutil.ReadAll(httpResponse.Body)
if err != nil {
return nil
}
ocspResponse, err := ocsp.ParseResponseForCert(httpBytes, cfg.serverCert, cfg.issuer)
if err != nil || verifyResponse(cfg, ocspResponse) != nil || ocspResponse.Status == ocsp.Unknown {
// If there was an error parsing/validating the response or the response was
// inconclusive, suppress the error because we want to ignore this responder.
return nil
}
// Send the conclusive response on the response channel and return a "done" error that
// will cause the errgroup to cancel all other in-progress requests.
ocspResponses <- ocspResponse
return errors.New("done")
})
}
_ = group.Wait()
select {
case res := <-ocspResponses:
return extractResponseDetails(res)
default:
// If there is no OCSP response on the response channel, all OCSP calls either failed or
// were inconclusive. Return nil.
return nil
}
}
// verifyResponse checks that the provided OCSP response is valid.
func verifyResponse(cfg config, res *ocsp.Response) error {
if err := verifyExtendedKeyUsage(cfg, res); err != nil {
return err
}
currTime := time.Now().UTC()
if res.ThisUpdate.After(currTime) {
return fmt.Errorf("reported thisUpdate time %s is after current time %s", res.ThisUpdate, currTime)
}
if !res.NextUpdate.IsZero() && res.NextUpdate.Before(currTime) {
return fmt.Errorf("reported nextUpdate time %s is before current time %s", res.NextUpdate, currTime)
}
return nil
}
func verifyExtendedKeyUsage(cfg config, res *ocsp.Response) error {
if res.Certificate == nil {
return nil
}
namesMatch := res.RawResponderName != nil && bytes.Equal(res.RawResponderName, cfg.issuer.RawSubject)
keyHashesMatch := res.ResponderKeyHash != nil && bytes.Equal(res.ResponderKeyHash, cfg.ocspRequest.IssuerKeyHash)
if namesMatch || keyHashesMatch {
// The responder certificate is the same as the issuer certificate.
return nil
}
// There is a delegate.
for _, extKeyUsage := range res.Certificate.ExtKeyUsage {
if extKeyUsage == x509.ExtKeyUsageOCSPSigning {
return nil
}
}
return errors.New("delegate responder certificate is missing the OCSP signing extended key usage")
}

View File

@@ -0,0 +1,16 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package ocsp
import "net/http"
// VerifyOptions specifies options to configure OCSP verification.
type VerifyOptions struct {
Cache Cache
DisableEndpointChecking bool
HTTPClient *http.Client
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,223 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/internal/logger"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// AbortTransaction performs an abortTransaction operation.
type AbortTransaction struct {
authenticator driver.Authenticator
recoveryToken bsoncore.Document
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
retry *driver.RetryMode
serverAPI *driver.ServerAPIOptions
logger *logger.Logger
}
// NewAbortTransaction constructs and returns a new AbortTransaction.
func NewAbortTransaction() *AbortTransaction {
return &AbortTransaction{}
}
func (at *AbortTransaction) processResponse(context.Context, bsoncore.Document, driver.ResponseInfo) error {
return nil
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (at *AbortTransaction) Execute(ctx context.Context) error {
if at.deployment == nil {
return errors.New("the AbortTransaction operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: at.command,
ProcessResponseFn: at.processResponse,
RetryMode: at.retry,
Type: driver.Write,
Client: at.session,
Clock: at.clock,
CommandMonitor: at.monitor,
Crypt: at.crypt,
Database: at.database,
Deployment: at.deployment,
Selector: at.selector,
WriteConcern: at.writeConcern,
ServerAPI: at.serverAPI,
Name: driverutil.AbortTransactionOp,
Authenticator: at.authenticator,
Logger: at.logger,
}.Execute(ctx)
}
func (at *AbortTransaction) command(dst []byte, _ description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "abortTransaction", 1)
if at.recoveryToken != nil {
dst = bsoncore.AppendDocumentElement(dst, "recoveryToken", at.recoveryToken)
}
return dst, nil
}
// RecoveryToken sets the recovery token to use when committing or aborting a sharded transaction.
func (at *AbortTransaction) RecoveryToken(recoveryToken bsoncore.Document) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.recoveryToken = recoveryToken
return at
}
// Session sets the session for this operation.
func (at *AbortTransaction) Session(session *session.Client) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.session = session
return at
}
// ClusterClock sets the cluster clock for this operation.
func (at *AbortTransaction) ClusterClock(clock *session.ClusterClock) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.clock = clock
return at
}
// Collection sets the collection that this command will run against.
func (at *AbortTransaction) Collection(collection string) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.collection = collection
return at
}
// CommandMonitor sets the monitor to use for APM events.
func (at *AbortTransaction) CommandMonitor(monitor *event.CommandMonitor) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.monitor = monitor
return at
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (at *AbortTransaction) Crypt(crypt driver.Crypt) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.crypt = crypt
return at
}
// Database sets the database to run this operation against.
func (at *AbortTransaction) Database(database string) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.database = database
return at
}
// Deployment sets the deployment to use for this operation.
func (at *AbortTransaction) Deployment(deployment driver.Deployment) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.deployment = deployment
return at
}
// ServerSelector sets the selector used to retrieve a server.
func (at *AbortTransaction) ServerSelector(selector description.ServerSelector) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.selector = selector
return at
}
// WriteConcern sets the write concern for this operation.
func (at *AbortTransaction) WriteConcern(writeConcern *writeconcern.WriteConcern) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.writeConcern = writeConcern
return at
}
// Retry enables retryable mode for this operation. Retries are handled automatically in driver.Operation.Execute based
// on how the operation is set.
func (at *AbortTransaction) Retry(retry driver.RetryMode) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.retry = &retry
return at
}
// ServerAPI sets the server API version for this operation.
func (at *AbortTransaction) ServerAPI(serverAPI *driver.ServerAPIOptions) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.serverAPI = serverAPI
return at
}
// Authenticator sets the authenticator to use for this operation.
func (at *AbortTransaction) Authenticator(authenticator driver.Authenticator) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.authenticator = authenticator
return at
}
// Logger sets the logger for this operation.
func (at *AbortTransaction) Logger(logger *logger.Logger) *AbortTransaction {
if at == nil {
at = new(AbortTransaction)
}
at.logger = logger
return at
}

View File

@@ -0,0 +1,441 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// Aggregate represents an aggregate operation.
type Aggregate struct {
authenticator driver.Authenticator
allowDiskUse *bool
batchSize *int32
bypassDocumentValidation *bool
collation bsoncore.Document
comment bsoncore.Value
hint bsoncore.Value
pipeline bsoncore.Document
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
readConcern *readconcern.ReadConcern
readPreference *readpref.ReadPref
retry *driver.RetryMode
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
crypt driver.Crypt
serverAPI *driver.ServerAPIOptions
let bsoncore.Document
hasOutputStage bool
customOptions map[string]bsoncore.Value
timeout *time.Duration
omitMaxTimeMS bool
rawData *bool
result driver.CursorResponse
}
// NewAggregate constructs and returns a new Aggregate.
func NewAggregate(pipeline bsoncore.Document) *Aggregate {
return &Aggregate{
pipeline: pipeline,
}
}
// Result returns the result of executing this operation.
func (a *Aggregate) Result(opts driver.CursorOptions) (*driver.BatchCursor, error) {
clientSession := a.session
clock := a.clock
opts.ServerAPI = a.serverAPI
return driver.NewBatchCursor(a.result, clientSession, clock, opts)
}
// ResultCursorResponse returns the underlying CursorResponse result of executing this
// operation.
func (a *Aggregate) ResultCursorResponse() driver.CursorResponse {
return a.result
}
func (a *Aggregate) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error {
curDoc, err := driver.ExtractCursorDocument(resp)
if err != nil {
return err
}
a.result, err = driver.NewCursorResponse(curDoc, info)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (a *Aggregate) Execute(ctx context.Context) error {
if a.deployment == nil {
return errors.New("the Aggregate operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: a.command,
ProcessResponseFn: a.processResponse,
Client: a.session,
Clock: a.clock,
CommandMonitor: a.monitor,
Database: a.database,
Deployment: a.deployment,
ReadConcern: a.readConcern,
ReadPreference: a.readPreference,
Type: driver.Read,
RetryMode: a.retry,
Selector: a.selector,
WriteConcern: a.writeConcern,
Crypt: a.crypt,
MinimumWriteConcernWireVersion: 5,
ServerAPI: a.serverAPI,
IsOutputAggregate: a.hasOutputStage,
Timeout: a.timeout,
Name: driverutil.AggregateOp,
Authenticator: a.authenticator,
OmitMaxTimeMS: a.omitMaxTimeMS,
}.Execute(ctx)
}
func (a *Aggregate) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
header := bsoncore.Value{Type: bsoncore.TypeString, Data: bsoncore.AppendString(nil, a.collection)}
if a.collection == "" {
header = bsoncore.Value{Type: bsoncore.TypeInt32, Data: []byte{0x01, 0x00, 0x00, 0x00}}
}
dst = bsoncore.AppendValueElement(dst, "aggregate", header)
cursorIdx, cursorDoc := bsoncore.AppendDocumentStart(nil)
if a.allowDiskUse != nil {
dst = bsoncore.AppendBooleanElement(dst, "allowDiskUse", *a.allowDiskUse)
}
if a.batchSize != nil {
cursorDoc = bsoncore.AppendInt32Element(cursorDoc, "batchSize", *a.batchSize)
}
if a.bypassDocumentValidation != nil {
dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *a.bypassDocumentValidation)
}
if a.collation != nil {
if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) {
return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5")
}
dst = bsoncore.AppendDocumentElement(dst, "collation", a.collation)
}
if a.comment.Type != bsoncore.Type(0) {
dst = bsoncore.AppendValueElement(dst, "comment", a.comment)
}
if a.hint.Type != bsoncore.Type(0) {
dst = bsoncore.AppendValueElement(dst, "hint", a.hint)
}
if a.pipeline != nil {
dst = bsoncore.AppendArrayElement(dst, "pipeline", a.pipeline)
}
if a.let != nil {
dst = bsoncore.AppendDocumentElement(dst, "let", a.let)
}
// Set rawData for 8.2+ servers.
if a.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
dst = bsoncore.AppendBooleanElement(dst, "rawData", *a.rawData)
}
for optionName, optionValue := range a.customOptions {
dst = bsoncore.AppendValueElement(dst, optionName, optionValue)
}
cursorDoc, _ = bsoncore.AppendDocumentEnd(cursorDoc, cursorIdx)
dst = bsoncore.AppendDocumentElement(dst, "cursor", cursorDoc)
return dst, nil
}
// AllowDiskUse enables writing to temporary files. When true, aggregation stages can write to the dbPath/_tmp directory.
func (a *Aggregate) AllowDiskUse(allowDiskUse bool) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.allowDiskUse = &allowDiskUse
return a
}
// BatchSize specifies the number of documents to return in every batch.
func (a *Aggregate) BatchSize(batchSize int32) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.batchSize = &batchSize
return a
}
// BypassDocumentValidation allows the write to opt-out of document level validation. This only applies when the $out stage is specified.
func (a *Aggregate) BypassDocumentValidation(bypassDocumentValidation bool) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.bypassDocumentValidation = &bypassDocumentValidation
return a
}
// Collation specifies a collation.
func (a *Aggregate) Collation(collation bsoncore.Document) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.collation = collation
return a
}
// Comment sets a value to help trace an operation.
func (a *Aggregate) Comment(comment bsoncore.Value) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.comment = comment
return a
}
// Hint specifies the index to use.
func (a *Aggregate) Hint(hint bsoncore.Value) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.hint = hint
return a
}
// Pipeline determines how data is transformed for an aggregation.
func (a *Aggregate) Pipeline(pipeline bsoncore.Document) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.pipeline = pipeline
return a
}
// Session sets the session for this operation.
func (a *Aggregate) Session(session *session.Client) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.session = session
return a
}
// ClusterClock sets the cluster clock for this operation.
func (a *Aggregate) ClusterClock(clock *session.ClusterClock) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.clock = clock
return a
}
// Collection sets the collection that this command will run against.
func (a *Aggregate) Collection(collection string) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.collection = collection
return a
}
// CommandMonitor sets the monitor to use for APM events.
func (a *Aggregate) CommandMonitor(monitor *event.CommandMonitor) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.monitor = monitor
return a
}
// Database sets the database to run this operation against.
func (a *Aggregate) Database(database string) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.database = database
return a
}
// Deployment sets the deployment to use for this operation.
func (a *Aggregate) Deployment(deployment driver.Deployment) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.deployment = deployment
return a
}
// ReadConcern specifies the read concern for this operation.
func (a *Aggregate) ReadConcern(readConcern *readconcern.ReadConcern) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.readConcern = readConcern
return a
}
// ReadPreference set the read preference used with this operation.
func (a *Aggregate) ReadPreference(readPreference *readpref.ReadPref) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.readPreference = readPreference
return a
}
// ServerSelector sets the selector used to retrieve a server.
func (a *Aggregate) ServerSelector(selector description.ServerSelector) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.selector = selector
return a
}
// WriteConcern sets the write concern for this operation.
func (a *Aggregate) WriteConcern(writeConcern *writeconcern.WriteConcern) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.writeConcern = writeConcern
return a
}
// Retry enables retryable writes for this operation. Retries are not handled automatically,
// instead a boolean is returned from Execute and SelectAndExecute that indicates if the
// operation can be retried. Retrying is handled by calling RetryExecute.
func (a *Aggregate) Retry(retry driver.RetryMode) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.retry = &retry
return a
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (a *Aggregate) Crypt(crypt driver.Crypt) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.crypt = crypt
return a
}
// ServerAPI sets the server API version for this operation.
func (a *Aggregate) ServerAPI(serverAPI *driver.ServerAPIOptions) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.serverAPI = serverAPI
return a
}
// Let specifies the let document to use. This option is only valid for server versions 5.0 and above.
func (a *Aggregate) Let(let bsoncore.Document) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.let = let
return a
}
// HasOutputStage specifies whether the aggregate contains an output stage. Used in determining when to
// append read preference at the operation level.
func (a *Aggregate) HasOutputStage(hos bool) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.hasOutputStage = hos
return a
}
// CustomOptions specifies extra options to use in the aggregate command.
func (a *Aggregate) CustomOptions(co map[string]bsoncore.Value) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.customOptions = co
return a
}
// Timeout sets the timeout for this operation.
func (a *Aggregate) Timeout(timeout *time.Duration) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.timeout = timeout
return a
}
// Authenticator sets the authenticator to use for this operation.
func (a *Aggregate) Authenticator(authenticator driver.Authenticator) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.authenticator = authenticator
return a
}
// OmitMaxTimeMS omits the automatically-calculated "maxTimeMS" from the
// command.
func (a *Aggregate) OmitMaxTimeMS(omit bool) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.omitMaxTimeMS = omit
return a
}
// RawData sets the rawData to access timeseries data in the compressed format.
func (a *Aggregate) RawData(rawData bool) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.rawData = &rawData
return a
}

View File

@@ -0,0 +1,237 @@
// Copyright (C) MongoDB, Inc. 2021-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/logger"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// Command is used to run a generic operation.
type Command struct {
authenticator driver.Authenticator
command bsoncore.Document
database string
deployment driver.Deployment
selector description.ServerSelector
readPreference *readpref.ReadPref
clock *session.ClusterClock
session *session.Client
monitor *event.CommandMonitor
resultResponse bsoncore.Document
resultCursor *driver.BatchCursor
crypt driver.Crypt
serverAPI *driver.ServerAPIOptions
createCursor bool
cursorOpts driver.CursorOptions
timeout *time.Duration
logger *logger.Logger
}
// NewCommand constructs and returns a new Command. Once the operation is executed, the result may only be accessed via
// the Result() function.
func NewCommand(command bsoncore.Document) *Command {
return &Command{
command: command,
}
}
// NewCursorCommand constructs a new Command. Once the operation is executed, the server response will be used to
// construct a cursor, which can be accessed via the ResultCursor() function.
func NewCursorCommand(command bsoncore.Document, cursorOpts driver.CursorOptions) *Command {
return &Command{
command: command,
cursorOpts: cursorOpts,
createCursor: true,
}
}
// Result returns the result of executing this operation.
func (c *Command) Result() bsoncore.Document { return c.resultResponse }
// ResultCursor returns the BatchCursor that was constructed using the command response. If the operation was not
// configured to create a cursor (i.e. it was created using NewCommand rather than NewCursorCommand), this function
// will return nil and an error.
func (c *Command) ResultCursor() (*driver.BatchCursor, error) {
if !c.createCursor {
return nil, errors.New("command operation was not configured to create a cursor, but a result cursor was requested")
}
return c.resultCursor, nil
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (c *Command) Execute(ctx context.Context) error {
if c.deployment == nil {
return errors.New("the Command operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) {
return append(dst, c.command[4:len(c.command)-1]...), nil
},
ProcessResponseFn: func(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error {
c.resultResponse = resp
if c.createCursor {
curDoc, err := driver.ExtractCursorDocument(resp)
if err != nil {
return err
}
cursorRes, err := driver.NewCursorResponse(curDoc, info)
if err != nil {
return err
}
c.resultCursor, err = driver.NewBatchCursor(cursorRes, c.session, c.clock, c.cursorOpts)
return err
}
return nil
},
Client: c.session,
Clock: c.clock,
CommandMonitor: c.monitor,
Database: c.database,
Deployment: c.deployment,
ReadPreference: c.readPreference,
Selector: c.selector,
Crypt: c.crypt,
ServerAPI: c.serverAPI,
Timeout: c.timeout,
Logger: c.logger,
Authenticator: c.authenticator,
}.Execute(ctx)
}
// Session sets the session for this operation.
func (c *Command) Session(session *session.Client) *Command {
if c == nil {
c = new(Command)
}
c.session = session
return c
}
// ClusterClock sets the cluster clock for this operation.
func (c *Command) ClusterClock(clock *session.ClusterClock) *Command {
if c == nil {
c = new(Command)
}
c.clock = clock
return c
}
// CommandMonitor sets the monitor to use for APM events.
func (c *Command) CommandMonitor(monitor *event.CommandMonitor) *Command {
if c == nil {
c = new(Command)
}
c.monitor = monitor
return c
}
// Database sets the database to run this operation against.
func (c *Command) Database(database string) *Command {
if c == nil {
c = new(Command)
}
c.database = database
return c
}
// Deployment sets the deployment to use for this operation.
func (c *Command) Deployment(deployment driver.Deployment) *Command {
if c == nil {
c = new(Command)
}
c.deployment = deployment
return c
}
// ReadPreference set the read preference used with this operation.
func (c *Command) ReadPreference(readPreference *readpref.ReadPref) *Command {
if c == nil {
c = new(Command)
}
c.readPreference = readPreference
return c
}
// ServerSelector sets the selector used to retrieve a server.
func (c *Command) ServerSelector(selector description.ServerSelector) *Command {
if c == nil {
c = new(Command)
}
c.selector = selector
return c
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (c *Command) Crypt(crypt driver.Crypt) *Command {
if c == nil {
c = new(Command)
}
c.crypt = crypt
return c
}
// ServerAPI sets the server API version for this operation.
func (c *Command) ServerAPI(serverAPI *driver.ServerAPIOptions) *Command {
if c == nil {
c = new(Command)
}
c.serverAPI = serverAPI
return c
}
// Timeout sets the timeout for this operation.
func (c *Command) Timeout(timeout *time.Duration) *Command {
if c == nil {
c = new(Command)
}
c.timeout = timeout
return c
}
// Logger sets the logger for this operation.
func (c *Command) Logger(logger *logger.Logger) *Command {
if c == nil {
c = new(Command)
}
c.logger = logger
return c
}
// Authenticator sets the authenticator to use for this operation.
func (c *Command) Authenticator(authenticator driver.Authenticator) *Command {
if c == nil {
c = new(Command)
}
c.authenticator = authenticator
return c
}

View File

@@ -0,0 +1,212 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/internal/logger"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// CommitTransaction attempts to commit a transaction.
type CommitTransaction struct {
authenticator driver.Authenticator
recoveryToken bsoncore.Document
session *session.Client
clock *session.ClusterClock
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
retry *driver.RetryMode
serverAPI *driver.ServerAPIOptions
logger *logger.Logger
}
// NewCommitTransaction constructs and returns a new CommitTransaction.
func NewCommitTransaction() *CommitTransaction {
return &CommitTransaction{}
}
func (ct *CommitTransaction) processResponse(context.Context, bsoncore.Document, driver.ResponseInfo) error {
return nil
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (ct *CommitTransaction) Execute(ctx context.Context) error {
if ct.deployment == nil {
return errors.New("the CommitTransaction operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: ct.command,
ProcessResponseFn: ct.processResponse,
RetryMode: ct.retry,
Type: driver.Write,
Client: ct.session,
Clock: ct.clock,
CommandMonitor: ct.monitor,
Crypt: ct.crypt,
Database: ct.database,
Deployment: ct.deployment,
Selector: ct.selector,
WriteConcern: ct.writeConcern,
ServerAPI: ct.serverAPI,
Name: driverutil.CommitTransactionOp,
Authenticator: ct.authenticator,
Logger: ct.logger,
}.Execute(ctx)
}
func (ct *CommitTransaction) command(dst []byte, _ description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "commitTransaction", 1)
if ct.recoveryToken != nil {
dst = bsoncore.AppendDocumentElement(dst, "recoveryToken", ct.recoveryToken)
}
return dst, nil
}
// RecoveryToken sets the recovery token to use when committing or aborting a sharded transaction.
func (ct *CommitTransaction) RecoveryToken(recoveryToken bsoncore.Document) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.recoveryToken = recoveryToken
return ct
}
// Session sets the session for this operation.
func (ct *CommitTransaction) Session(session *session.Client) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.session = session
return ct
}
// ClusterClock sets the cluster clock for this operation.
func (ct *CommitTransaction) ClusterClock(clock *session.ClusterClock) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.clock = clock
return ct
}
// CommandMonitor sets the monitor to use for APM events.
func (ct *CommitTransaction) CommandMonitor(monitor *event.CommandMonitor) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.monitor = monitor
return ct
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (ct *CommitTransaction) Crypt(crypt driver.Crypt) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.crypt = crypt
return ct
}
// Database sets the database to run this operation against.
func (ct *CommitTransaction) Database(database string) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.database = database
return ct
}
// Deployment sets the deployment to use for this operation.
func (ct *CommitTransaction) Deployment(deployment driver.Deployment) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.deployment = deployment
return ct
}
// ServerSelector sets the selector used to retrieve a server.
func (ct *CommitTransaction) ServerSelector(selector description.ServerSelector) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.selector = selector
return ct
}
// WriteConcern sets the write concern for this operation.
func (ct *CommitTransaction) WriteConcern(writeConcern *writeconcern.WriteConcern) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.writeConcern = writeConcern
return ct
}
// Retry enables retryable mode for this operation. Retries are handled automatically in driver.Operation.Execute based
// on how the operation is set.
func (ct *CommitTransaction) Retry(retry driver.RetryMode) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.retry = &retry
return ct
}
// ServerAPI sets the server API version for this operation.
func (ct *CommitTransaction) ServerAPI(serverAPI *driver.ServerAPIOptions) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.serverAPI = serverAPI
return ct
}
// Authenticator sets the authenticator to use for this operation.
func (ct *CommitTransaction) Authenticator(authenticator driver.Authenticator) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.authenticator = authenticator
return ct
}
// Logger sets the logger for this operation.
func (ct *CommitTransaction) Logger(logger *logger.Logger) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.logger = logger
return ct
}

View File

@@ -0,0 +1,326 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"fmt"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// Count represents a count operation.
type Count struct {
authenticator driver.Authenticator
query bsoncore.Document
session *session.Client
clock *session.ClusterClock
collection string
comment bsoncore.Value
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
readConcern *readconcern.ReadConcern
readPreference *readpref.ReadPref
selector description.ServerSelector
retry *driver.RetryMode
result CountResult
serverAPI *driver.ServerAPIOptions
timeout *time.Duration
rawData *bool
}
// CountResult represents a count result returned by the server.
type CountResult struct {
// The number of documents found
N int64
}
func buildCountResult(response bsoncore.Document) (CountResult, error) {
elements, err := response.Elements()
if err != nil {
return CountResult{}, err
}
cr := CountResult{}
for _, element := range elements {
switch element.Key() {
case "n": // for count using original command
var ok bool
cr.N, ok = element.Value().AsInt64OK()
if !ok {
return cr, fmt.Errorf("response field 'n' is type int64, but received BSON type %s",
element.Value().Type)
}
case "cursor": // for count using aggregate with $collStats
firstBatch, err := element.Value().Document().LookupErr("firstBatch")
if err != nil {
return cr, err
}
// get count value from first batch
val := firstBatch.Array().Index(0)
count, err := val.Document().LookupErr("n")
if err != nil {
return cr, err
}
// use count as Int64 for result
var ok bool
cr.N, ok = count.AsInt64OK()
if !ok {
return cr, fmt.Errorf("response field 'n' is type int64, but received BSON type %s",
element.Value().Type)
}
}
}
return cr, nil
}
// NewCount constructs and returns a new Count.
func NewCount() *Count {
return &Count{}
}
// Result returns the result of executing this operation.
func (c *Count) Result() CountResult { return c.result }
func (c *Count) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
var err error
c.result, err = buildCountResult(resp)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (c *Count) Execute(ctx context.Context) error {
if c.deployment == nil {
return errors.New("the Count operation must have a Deployment set before Execute can be called")
}
err := driver.Operation{
CommandFn: c.command,
ProcessResponseFn: c.processResponse,
RetryMode: c.retry,
Type: driver.Read,
Client: c.session,
Clock: c.clock,
CommandMonitor: c.monitor,
Crypt: c.crypt,
Database: c.database,
Deployment: c.deployment,
ReadConcern: c.readConcern,
ReadPreference: c.readPreference,
Selector: c.selector,
ServerAPI: c.serverAPI,
Timeout: c.timeout,
Name: driverutil.CountOp,
Authenticator: c.authenticator,
}.Execute(ctx)
// Swallow error if NamespaceNotFound(26) is returned from aggregate on non-existent namespace
if err != nil {
dErr, ok := err.(driver.Error)
if ok && dErr.Code == 26 {
err = nil
}
}
return err
}
func (c *Count) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "count", c.collection)
if c.query != nil {
dst = bsoncore.AppendDocumentElement(dst, "query", c.query)
}
if c.comment.Type != bsoncore.Type(0) {
dst = bsoncore.AppendValueElement(dst, "comment", c.comment)
}
// Set rawData for 8.2+ servers.
if c.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
dst = bsoncore.AppendBooleanElement(dst, "rawData", *c.rawData)
}
return dst, nil
}
// Query determines what results are returned from find.
func (c *Count) Query(query bsoncore.Document) *Count {
if c == nil {
c = new(Count)
}
c.query = query
return c
}
// Session sets the session for this operation.
func (c *Count) Session(session *session.Client) *Count {
if c == nil {
c = new(Count)
}
c.session = session
return c
}
// ClusterClock sets the cluster clock for this operation.
func (c *Count) ClusterClock(clock *session.ClusterClock) *Count {
if c == nil {
c = new(Count)
}
c.clock = clock
return c
}
// Collection sets the collection that this command will run against.
func (c *Count) Collection(collection string) *Count {
if c == nil {
c = new(Count)
}
c.collection = collection
return c
}
// Comment sets a value to help trace an operation.
func (c *Count) Comment(comment bsoncore.Value) *Count {
if c == nil {
c = new(Count)
}
c.comment = comment
return c
}
// CommandMonitor sets the monitor to use for APM events.
func (c *Count) CommandMonitor(monitor *event.CommandMonitor) *Count {
if c == nil {
c = new(Count)
}
c.monitor = monitor
return c
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (c *Count) Crypt(crypt driver.Crypt) *Count {
if c == nil {
c = new(Count)
}
c.crypt = crypt
return c
}
// Database sets the database to run this operation against.
func (c *Count) Database(database string) *Count {
if c == nil {
c = new(Count)
}
c.database = database
return c
}
// Deployment sets the deployment to use for this operation.
func (c *Count) Deployment(deployment driver.Deployment) *Count {
if c == nil {
c = new(Count)
}
c.deployment = deployment
return c
}
// ReadConcern specifies the read concern for this operation.
func (c *Count) ReadConcern(readConcern *readconcern.ReadConcern) *Count {
if c == nil {
c = new(Count)
}
c.readConcern = readConcern
return c
}
// ReadPreference set the read preference used with this operation.
func (c *Count) ReadPreference(readPreference *readpref.ReadPref) *Count {
if c == nil {
c = new(Count)
}
c.readPreference = readPreference
return c
}
// ServerSelector sets the selector used to retrieve a server.
func (c *Count) ServerSelector(selector description.ServerSelector) *Count {
if c == nil {
c = new(Count)
}
c.selector = selector
return c
}
// Retry enables retryable mode for this operation. Retries are handled automatically in driver.Operation.Execute based
// on how the operation is set.
func (c *Count) Retry(retry driver.RetryMode) *Count {
if c == nil {
c = new(Count)
}
c.retry = &retry
return c
}
// ServerAPI sets the server API version for this operation.
func (c *Count) ServerAPI(serverAPI *driver.ServerAPIOptions) *Count {
if c == nil {
c = new(Count)
}
c.serverAPI = serverAPI
return c
}
// Timeout sets the timeout for this operation.
func (c *Count) Timeout(timeout *time.Duration) *Count {
if c == nil {
c = new(Count)
}
c.timeout = timeout
return c
}
// Authenticator sets the authenticator to use for this operation.
func (c *Count) Authenticator(authenticator driver.Authenticator) *Count {
if c == nil {
c = new(Count)
}
c.authenticator = authenticator
return c
}
// RawData sets the rawData to access timeseries data in the compressed format.
func (c *Count) RawData(rawData bool) *Count {
if c == nil {
c = new(Count)
}
c.rawData = &rawData
return c
}

View File

@@ -0,0 +1,414 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// Create represents a create operation.
type Create struct {
authenticator driver.Authenticator
capped *bool
collation bsoncore.Document
changeStreamPreAndPostImages bsoncore.Document
collectionName *string
indexOptionDefaults bsoncore.Document
max *int64
pipeline bsoncore.Document
size *int64
storageEngine bsoncore.Document
validationAction *string
validationLevel *string
validator bsoncore.Document
viewOn *string
session *session.Client
clock *session.ClusterClock
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
serverAPI *driver.ServerAPIOptions
expireAfterSeconds *int64
timeSeries bsoncore.Document
encryptedFields bsoncore.Document
clusteredIndex bsoncore.Document
}
// NewCreate constructs and returns a new Create.
func NewCreate(collectionName string) *Create {
return &Create{
collectionName: &collectionName,
}
}
func (c *Create) processResponse(context.Context, bsoncore.Document, driver.ResponseInfo) error {
return nil
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (c *Create) Execute(ctx context.Context) error {
if c.deployment == nil {
return errors.New("the Create operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: c.command,
ProcessResponseFn: c.processResponse,
Client: c.session,
Clock: c.clock,
CommandMonitor: c.monitor,
Crypt: c.crypt,
Database: c.database,
Deployment: c.deployment,
Selector: c.selector,
WriteConcern: c.writeConcern,
ServerAPI: c.serverAPI,
Authenticator: c.authenticator,
}.Execute(ctx)
}
func (c *Create) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
if c.collectionName != nil {
dst = bsoncore.AppendStringElement(dst, "create", *c.collectionName)
}
if c.capped != nil {
dst = bsoncore.AppendBooleanElement(dst, "capped", *c.capped)
}
if c.changeStreamPreAndPostImages != nil {
dst = bsoncore.AppendDocumentElement(dst, "changeStreamPreAndPostImages", c.changeStreamPreAndPostImages)
}
if c.collation != nil {
if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) {
return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5")
}
dst = bsoncore.AppendDocumentElement(dst, "collation", c.collation)
}
if c.indexOptionDefaults != nil {
dst = bsoncore.AppendDocumentElement(dst, "indexOptionDefaults", c.indexOptionDefaults)
}
if c.max != nil {
dst = bsoncore.AppendInt64Element(dst, "max", *c.max)
}
if c.pipeline != nil {
dst = bsoncore.AppendArrayElement(dst, "pipeline", c.pipeline)
}
if c.size != nil {
dst = bsoncore.AppendInt64Element(dst, "size", *c.size)
}
if c.storageEngine != nil {
dst = bsoncore.AppendDocumentElement(dst, "storageEngine", c.storageEngine)
}
if c.validationAction != nil {
dst = bsoncore.AppendStringElement(dst, "validationAction", *c.validationAction)
}
if c.validationLevel != nil {
dst = bsoncore.AppendStringElement(dst, "validationLevel", *c.validationLevel)
}
if c.validator != nil {
dst = bsoncore.AppendDocumentElement(dst, "validator", c.validator)
}
if c.viewOn != nil {
dst = bsoncore.AppendStringElement(dst, "viewOn", *c.viewOn)
}
if c.expireAfterSeconds != nil {
dst = bsoncore.AppendInt64Element(dst, "expireAfterSeconds", *c.expireAfterSeconds)
}
if c.timeSeries != nil {
dst = bsoncore.AppendDocumentElement(dst, "timeseries", c.timeSeries)
}
if c.encryptedFields != nil {
dst = bsoncore.AppendDocumentElement(dst, "encryptedFields", c.encryptedFields)
}
if c.clusteredIndex != nil {
dst = bsoncore.AppendDocumentElement(dst, "clusteredIndex", c.clusteredIndex)
}
return dst, nil
}
// Capped specifies if the collection is capped.
func (c *Create) Capped(capped bool) *Create {
if c == nil {
c = new(Create)
}
c.capped = &capped
return c
}
// Collation specifies a collation.
func (c *Create) Collation(collation bsoncore.Document) *Create {
if c == nil {
c = new(Create)
}
c.collation = collation
return c
}
// ChangeStreamPreAndPostImages specifies how change streams opened against the collection can return pre-
// and post-images of updated documents. This option is only valid for server versions 6.0 and above.
func (c *Create) ChangeStreamPreAndPostImages(csppi bsoncore.Document) *Create {
if c == nil {
c = new(Create)
}
c.changeStreamPreAndPostImages = csppi
return c
}
// CollectionName specifies the name of the collection to create.
func (c *Create) CollectionName(collectionName string) *Create {
if c == nil {
c = new(Create)
}
c.collectionName = &collectionName
return c
}
// IndexOptionDefaults specifies a default configuration for indexes on the collection.
func (c *Create) IndexOptionDefaults(indexOptionDefaults bsoncore.Document) *Create {
if c == nil {
c = new(Create)
}
c.indexOptionDefaults = indexOptionDefaults
return c
}
// Max specifies the maximum number of documents allowed in a capped collection.
func (c *Create) Max(max int64) *Create {
if c == nil {
c = new(Create)
}
c.max = &max
return c
}
// Pipeline specifies the agggregtion pipeline to be run against the source to create the view.
func (c *Create) Pipeline(pipeline bsoncore.Document) *Create {
if c == nil {
c = new(Create)
}
c.pipeline = pipeline
return c
}
// Size specifies the maximum size in bytes for a capped collection.
func (c *Create) Size(size int64) *Create {
if c == nil {
c = new(Create)
}
c.size = &size
return c
}
// StorageEngine specifies the storage engine to use for the index.
func (c *Create) StorageEngine(storageEngine bsoncore.Document) *Create {
if c == nil {
c = new(Create)
}
c.storageEngine = storageEngine
return c
}
// ValidationAction specifies what should happen if a document being inserted does not pass validation.
func (c *Create) ValidationAction(validationAction string) *Create {
if c == nil {
c = new(Create)
}
c.validationAction = &validationAction
return c
}
// ValidationLevel specifies how strictly the server applies validation rules to existing documents in the collection
// during update operations.
func (c *Create) ValidationLevel(validationLevel string) *Create {
if c == nil {
c = new(Create)
}
c.validationLevel = &validationLevel
return c
}
// Validator specifies validation rules for the collection.
func (c *Create) Validator(validator bsoncore.Document) *Create {
if c == nil {
c = new(Create)
}
c.validator = validator
return c
}
// ViewOn specifies the name of the source collection or view on which the view will be created.
func (c *Create) ViewOn(viewOn string) *Create {
if c == nil {
c = new(Create)
}
c.viewOn = &viewOn
return c
}
// Session sets the session for this operation.
func (c *Create) Session(session *session.Client) *Create {
if c == nil {
c = new(Create)
}
c.session = session
return c
}
// ClusterClock sets the cluster clock for this operation.
func (c *Create) ClusterClock(clock *session.ClusterClock) *Create {
if c == nil {
c = new(Create)
}
c.clock = clock
return c
}
// CommandMonitor sets the monitor to use for APM events.
func (c *Create) CommandMonitor(monitor *event.CommandMonitor) *Create {
if c == nil {
c = new(Create)
}
c.monitor = monitor
return c
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (c *Create) Crypt(crypt driver.Crypt) *Create {
if c == nil {
c = new(Create)
}
c.crypt = crypt
return c
}
// Database sets the database to run this operation against.
func (c *Create) Database(database string) *Create {
if c == nil {
c = new(Create)
}
c.database = database
return c
}
// Deployment sets the deployment to use for this operation.
func (c *Create) Deployment(deployment driver.Deployment) *Create {
if c == nil {
c = new(Create)
}
c.deployment = deployment
return c
}
// ServerSelector sets the selector used to retrieve a server.
func (c *Create) ServerSelector(selector description.ServerSelector) *Create {
if c == nil {
c = new(Create)
}
c.selector = selector
return c
}
// WriteConcern sets the write concern for this operation.
func (c *Create) WriteConcern(writeConcern *writeconcern.WriteConcern) *Create {
if c == nil {
c = new(Create)
}
c.writeConcern = writeConcern
return c
}
// ServerAPI sets the server API version for this operation.
func (c *Create) ServerAPI(serverAPI *driver.ServerAPIOptions) *Create {
if c == nil {
c = new(Create)
}
c.serverAPI = serverAPI
return c
}
// ExpireAfterSeconds sets the seconds to wait before deleting old time-series data.
func (c *Create) ExpireAfterSeconds(eas int64) *Create {
if c == nil {
c = new(Create)
}
c.expireAfterSeconds = &eas
return c
}
// TimeSeries sets the time series options for this operation.
func (c *Create) TimeSeries(timeSeries bsoncore.Document) *Create {
if c == nil {
c = new(Create)
}
c.timeSeries = timeSeries
return c
}
// EncryptedFields sets the EncryptedFields for this operation.
func (c *Create) EncryptedFields(ef bsoncore.Document) *Create {
if c == nil {
c = new(Create)
}
c.encryptedFields = ef
return c
}
// ClusteredIndex sets the ClusteredIndex option for this operation.
func (c *Create) ClusteredIndex(ci bsoncore.Document) *Create {
if c == nil {
c = new(Create)
}
c.clusteredIndex = ci
return c
}
// Authenticator sets the authenticator to use for this operation.
func (c *Create) Authenticator(authenticator driver.Authenticator) *Create {
if c == nil {
c = new(Create)
}
c.authenticator = authenticator
return c
}

View File

@@ -0,0 +1,293 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"fmt"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// CreateIndexes performs a createIndexes operation.
type CreateIndexes struct {
authenticator driver.Authenticator
commitQuorum bsoncore.Value
indexes bsoncore.Document
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
result CreateIndexesResult
serverAPI *driver.ServerAPIOptions
timeout *time.Duration
rawData *bool
}
// CreateIndexesResult represents a createIndexes result returned by the server.
type CreateIndexesResult struct {
// If the collection was created automatically.
CreatedCollectionAutomatically bool
// The number of indexes existing after this command.
IndexesAfter int32
// The number of indexes existing before this command.
IndexesBefore int32
}
func buildCreateIndexesResult(response bsoncore.Document) (CreateIndexesResult, error) {
elements, err := response.Elements()
if err != nil {
return CreateIndexesResult{}, err
}
cir := CreateIndexesResult{}
for _, element := range elements {
switch element.Key() {
case "createdCollectionAutomatically":
var ok bool
cir.CreatedCollectionAutomatically, ok = element.Value().BooleanOK()
if !ok {
return cir, fmt.Errorf("response field 'createdCollectionAutomatically' is type bool, but received BSON type %s", element.Value().Type)
}
case "indexesAfter":
var ok bool
cir.IndexesAfter, ok = element.Value().AsInt32OK()
if !ok {
return cir, fmt.Errorf("response field 'indexesAfter' is type int32, but received BSON type %s", element.Value().Type)
}
case "indexesBefore":
var ok bool
cir.IndexesBefore, ok = element.Value().AsInt32OK()
if !ok {
return cir, fmt.Errorf("response field 'indexesBefore' is type int32, but received BSON type %s", element.Value().Type)
}
}
}
return cir, nil
}
// NewCreateIndexes constructs and returns a new CreateIndexes.
func NewCreateIndexes(indexes bsoncore.Document) *CreateIndexes {
return &CreateIndexes{
indexes: indexes,
}
}
// Result returns the result of executing this operation.
func (ci *CreateIndexes) Result() CreateIndexesResult { return ci.result }
func (ci *CreateIndexes) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
var err error
ci.result, err = buildCreateIndexesResult(resp)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (ci *CreateIndexes) Execute(ctx context.Context) error {
if ci.deployment == nil {
return errors.New("the CreateIndexes operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: ci.command,
ProcessResponseFn: ci.processResponse,
Client: ci.session,
Clock: ci.clock,
CommandMonitor: ci.monitor,
Crypt: ci.crypt,
Database: ci.database,
Deployment: ci.deployment,
Selector: ci.selector,
WriteConcern: ci.writeConcern,
ServerAPI: ci.serverAPI,
Timeout: ci.timeout,
Name: driverutil.CreateIndexesOp,
Authenticator: ci.authenticator,
}.Execute(ctx)
}
func (ci *CreateIndexes) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "createIndexes", ci.collection)
if ci.commitQuorum.Type != bsoncore.Type(0) {
if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 9) {
return nil, errors.New("the 'commitQuorum' command parameter requires a minimum server wire version of 9")
}
dst = bsoncore.AppendValueElement(dst, "commitQuorum", ci.commitQuorum)
}
if ci.indexes != nil {
dst = bsoncore.AppendArrayElement(dst, "indexes", ci.indexes)
}
// Set rawData for 8.2+ servers.
if ci.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
dst = bsoncore.AppendBooleanElement(dst, "rawData", *ci.rawData)
}
return dst, nil
}
// CommitQuorum specifies the number of data-bearing members of a replica set, including the primary, that must
// complete the index builds successfully before the primary marks the indexes as ready. This should either be a
// string or int32 value.
func (ci *CreateIndexes) CommitQuorum(commitQuorum bsoncore.Value) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.commitQuorum = commitQuorum
return ci
}
// Indexes specifies an array containing index specification documents for the indexes being created.
func (ci *CreateIndexes) Indexes(indexes bsoncore.Document) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.indexes = indexes
return ci
}
// Session sets the session for this operation.
func (ci *CreateIndexes) Session(session *session.Client) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.session = session
return ci
}
// ClusterClock sets the cluster clock for this operation.
func (ci *CreateIndexes) ClusterClock(clock *session.ClusterClock) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.clock = clock
return ci
}
// Collection sets the collection that this command will run against.
func (ci *CreateIndexes) Collection(collection string) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.collection = collection
return ci
}
// CommandMonitor sets the monitor to use for APM events.
func (ci *CreateIndexes) CommandMonitor(monitor *event.CommandMonitor) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.monitor = monitor
return ci
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (ci *CreateIndexes) Crypt(crypt driver.Crypt) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.crypt = crypt
return ci
}
// Database sets the database to run this operation against.
func (ci *CreateIndexes) Database(database string) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.database = database
return ci
}
// Deployment sets the deployment to use for this operation.
func (ci *CreateIndexes) Deployment(deployment driver.Deployment) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.deployment = deployment
return ci
}
// ServerSelector sets the selector used to retrieve a server.
func (ci *CreateIndexes) ServerSelector(selector description.ServerSelector) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.selector = selector
return ci
}
// WriteConcern sets the write concern for this operation.
func (ci *CreateIndexes) WriteConcern(writeConcern *writeconcern.WriteConcern) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.writeConcern = writeConcern
return ci
}
// ServerAPI sets the server API version for this operation.
func (ci *CreateIndexes) ServerAPI(serverAPI *driver.ServerAPIOptions) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.serverAPI = serverAPI
return ci
}
// Timeout sets the timeout for this operation.
func (ci *CreateIndexes) Timeout(timeout *time.Duration) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.timeout = timeout
return ci
}
// Authenticator sets the authenticator to use for this operation.
func (ci *CreateIndexes) Authenticator(authenticator driver.Authenticator) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.authenticator = authenticator
return ci
}
// RawData sets the rawData to access timeseries data in the compressed format.
func (ci *CreateIndexes) RawData(rawData bool) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.rawData = &rawData
return ci
}

View File

@@ -0,0 +1,250 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"fmt"
"time"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// CreateSearchIndexes performs a createSearchIndexes operation.
type CreateSearchIndexes struct {
authenticator driver.Authenticator
indexes bsoncore.Document
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
selector description.ServerSelector
result CreateSearchIndexesResult
serverAPI *driver.ServerAPIOptions
timeout *time.Duration
}
// CreateSearchIndexResult represents a single search index result in CreateSearchIndexesResult.
type CreateSearchIndexResult struct {
Name string
}
// CreateSearchIndexesResult represents a createSearchIndexes result returned by the server.
type CreateSearchIndexesResult struct {
IndexesCreated []CreateSearchIndexResult
}
func buildCreateSearchIndexesResult(response bsoncore.Document) (CreateSearchIndexesResult, error) {
elements, err := response.Elements()
if err != nil {
return CreateSearchIndexesResult{}, err
}
csir := CreateSearchIndexesResult{}
for _, element := range elements {
switch element.Key() {
case "indexesCreated":
arr, ok := element.Value().ArrayOK()
if !ok {
return csir, fmt.Errorf("response field 'indexesCreated' is type array, but received BSON type %s", element.Value().Type)
}
var values []bsoncore.Value
values, err = arr.Values()
if err != nil {
break
}
for _, val := range values {
valDoc, ok := val.DocumentOK()
if !ok {
return csir, fmt.Errorf("indexesCreated value is type document, but received BSON type %s", val.Type)
}
var indexesCreated CreateSearchIndexResult
if err = bson.Unmarshal(valDoc, &indexesCreated); err != nil {
return csir, err
}
csir.IndexesCreated = append(csir.IndexesCreated, indexesCreated)
}
}
}
return csir, nil
}
// NewCreateSearchIndexes constructs and returns a new CreateSearchIndexes.
func NewCreateSearchIndexes(indexes bsoncore.Document) *CreateSearchIndexes {
return &CreateSearchIndexes{
indexes: indexes,
}
}
// Result returns the result of executing this operation.
func (csi *CreateSearchIndexes) Result() CreateSearchIndexesResult { return csi.result }
func (csi *CreateSearchIndexes) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
var err error
csi.result, err = buildCreateSearchIndexesResult(resp)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (csi *CreateSearchIndexes) Execute(ctx context.Context) error {
if csi.deployment == nil {
return errors.New("the CreateSearchIndexes operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: csi.command,
ProcessResponseFn: csi.processResponse,
Client: csi.session,
Clock: csi.clock,
CommandMonitor: csi.monitor,
Crypt: csi.crypt,
Database: csi.database,
Deployment: csi.deployment,
Selector: csi.selector,
ServerAPI: csi.serverAPI,
Timeout: csi.timeout,
Authenticator: csi.authenticator,
}.Execute(ctx)
}
func (csi *CreateSearchIndexes) command(dst []byte, _ description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "createSearchIndexes", csi.collection)
if csi.indexes != nil {
dst = bsoncore.AppendArrayElement(dst, "indexes", csi.indexes)
}
return dst, nil
}
// Indexes specifies an array containing index specification documents for the indexes being created.
func (csi *CreateSearchIndexes) Indexes(indexes bsoncore.Document) *CreateSearchIndexes {
if csi == nil {
csi = new(CreateSearchIndexes)
}
csi.indexes = indexes
return csi
}
// Session sets the session for this operation.
func (csi *CreateSearchIndexes) Session(session *session.Client) *CreateSearchIndexes {
if csi == nil {
csi = new(CreateSearchIndexes)
}
csi.session = session
return csi
}
// ClusterClock sets the cluster clock for this operation.
func (csi *CreateSearchIndexes) ClusterClock(clock *session.ClusterClock) *CreateSearchIndexes {
if csi == nil {
csi = new(CreateSearchIndexes)
}
csi.clock = clock
return csi
}
// Collection sets the collection that this command will run against.
func (csi *CreateSearchIndexes) Collection(collection string) *CreateSearchIndexes {
if csi == nil {
csi = new(CreateSearchIndexes)
}
csi.collection = collection
return csi
}
// CommandMonitor sets the monitor to use for APM events.
func (csi *CreateSearchIndexes) CommandMonitor(monitor *event.CommandMonitor) *CreateSearchIndexes {
if csi == nil {
csi = new(CreateSearchIndexes)
}
csi.monitor = monitor
return csi
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (csi *CreateSearchIndexes) Crypt(crypt driver.Crypt) *CreateSearchIndexes {
if csi == nil {
csi = new(CreateSearchIndexes)
}
csi.crypt = crypt
return csi
}
// Database sets the database to run this operation against.
func (csi *CreateSearchIndexes) Database(database string) *CreateSearchIndexes {
if csi == nil {
csi = new(CreateSearchIndexes)
}
csi.database = database
return csi
}
// Deployment sets the deployment to use for this operation.
func (csi *CreateSearchIndexes) Deployment(deployment driver.Deployment) *CreateSearchIndexes {
if csi == nil {
csi = new(CreateSearchIndexes)
}
csi.deployment = deployment
return csi
}
// ServerSelector sets the selector used to retrieve a server.
func (csi *CreateSearchIndexes) ServerSelector(selector description.ServerSelector) *CreateSearchIndexes {
if csi == nil {
csi = new(CreateSearchIndexes)
}
csi.selector = selector
return csi
}
// ServerAPI sets the server API version for this operation.
func (csi *CreateSearchIndexes) ServerAPI(serverAPI *driver.ServerAPIOptions) *CreateSearchIndexes {
if csi == nil {
csi = new(CreateSearchIndexes)
}
csi.serverAPI = serverAPI
return csi
}
// Timeout sets the timeout for this operation.
func (csi *CreateSearchIndexes) Timeout(timeout *time.Duration) *CreateSearchIndexes {
if csi == nil {
csi = new(CreateSearchIndexes)
}
csi.timeout = timeout
return csi
}
// Authenticator sets the authenticator to use for this operation.
func (csi *CreateSearchIndexes) Authenticator(authenticator driver.Authenticator) *CreateSearchIndexes {
if csi == nil {
csi = new(CreateSearchIndexes)
}
csi.authenticator = authenticator
return csi
}

View File

@@ -0,0 +1,353 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"fmt"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/internal/logger"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// Delete performs a delete operation
type Delete struct {
authenticator driver.Authenticator
comment bsoncore.Value
deletes []bsoncore.Document
ordered *bool
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
retry *driver.RetryMode
hint *bool
result DeleteResult
serverAPI *driver.ServerAPIOptions
let bsoncore.Document
timeout *time.Duration
rawData *bool
logger *logger.Logger
}
// DeleteResult represents a delete result returned by the server.
type DeleteResult struct {
// Number of documents successfully deleted.
N int64
}
func buildDeleteResult(response bsoncore.Document) (DeleteResult, error) {
elements, err := response.Elements()
if err != nil {
return DeleteResult{}, err
}
dr := DeleteResult{}
for _, element := range elements {
if element.Key() == "n" {
var ok bool
dr.N, ok = element.Value().AsInt64OK()
if !ok {
return dr, fmt.Errorf("response field 'n' is type int32 or int64, but received BSON type %s", element.Value().Type)
}
}
}
return dr, nil
}
// NewDelete constructs and returns a new Delete.
func NewDelete(deletes ...bsoncore.Document) *Delete {
return &Delete{
deletes: deletes,
}
}
// Result returns the result of executing this operation.
func (d *Delete) Result() DeleteResult { return d.result }
func (d *Delete) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
dr, err := buildDeleteResult(resp)
d.result.N += dr.N
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (d *Delete) Execute(ctx context.Context) error {
if d.deployment == nil {
return errors.New("the Delete operation must have a Deployment set before Execute can be called")
}
batches := &driver.Batches{
Identifier: "deletes",
Documents: d.deletes,
Ordered: d.ordered,
}
return driver.Operation{
CommandFn: d.command,
ProcessResponseFn: d.processResponse,
Batches: batches,
RetryMode: d.retry,
Type: driver.Write,
Client: d.session,
Clock: d.clock,
CommandMonitor: d.monitor,
Crypt: d.crypt,
Database: d.database,
Deployment: d.deployment,
Selector: d.selector,
WriteConcern: d.writeConcern,
ServerAPI: d.serverAPI,
Timeout: d.timeout,
Logger: d.logger,
Name: driverutil.DeleteOp,
Authenticator: d.authenticator,
}.Execute(ctx)
}
func (d *Delete) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "delete", d.collection)
if d.comment.Type != bsoncore.Type(0) {
dst = bsoncore.AppendValueElement(dst, "comment", d.comment)
}
if d.ordered != nil {
dst = bsoncore.AppendBooleanElement(dst, "ordered", *d.ordered)
}
if d.hint != nil && *d.hint {
if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) {
return nil, errors.New("the 'hint' command parameter requires a minimum server wire version of 5")
}
if !d.writeConcern.Acknowledged() {
return nil, errUnacknowledgedHint
}
}
if d.let != nil {
dst = bsoncore.AppendDocumentElement(dst, "let", d.let)
}
// Set rawData for 8.2+ servers.
if d.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
dst = bsoncore.AppendBooleanElement(dst, "rawData", *d.rawData)
}
return dst, nil
}
// Deletes adds documents to this operation that will be used to determine what documents to delete when this operation
// is executed. These documents should have the form {q: <query>, limit: <integer limit>, collation: <document>}. The
// collation field is optional. If limit is 0, there will be no limit on the number of documents deleted.
func (d *Delete) Deletes(deletes ...bsoncore.Document) *Delete {
if d == nil {
d = new(Delete)
}
d.deletes = deletes
return d
}
// Ordered sets ordered. If true, when a write fails, the operation will return the error, when
// false write failures do not stop execution of the operation.
func (d *Delete) Ordered(ordered bool) *Delete {
if d == nil {
d = new(Delete)
}
d.ordered = &ordered
return d
}
// Session sets the session for this operation.
func (d *Delete) Session(session *session.Client) *Delete {
if d == nil {
d = new(Delete)
}
d.session = session
return d
}
// ClusterClock sets the cluster clock for this operation.
func (d *Delete) ClusterClock(clock *session.ClusterClock) *Delete {
if d == nil {
d = new(Delete)
}
d.clock = clock
return d
}
// Collection sets the collection that this command will run against.
func (d *Delete) Collection(collection string) *Delete {
if d == nil {
d = new(Delete)
}
d.collection = collection
return d
}
// Comment sets a value to help trace an operation.
func (d *Delete) Comment(comment bsoncore.Value) *Delete {
if d == nil {
d = new(Delete)
}
d.comment = comment
return d
}
// CommandMonitor sets the monitor to use for APM events.
func (d *Delete) CommandMonitor(monitor *event.CommandMonitor) *Delete {
if d == nil {
d = new(Delete)
}
d.monitor = monitor
return d
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (d *Delete) Crypt(crypt driver.Crypt) *Delete {
if d == nil {
d = new(Delete)
}
d.crypt = crypt
return d
}
// Database sets the database to run this operation against.
func (d *Delete) Database(database string) *Delete {
if d == nil {
d = new(Delete)
}
d.database = database
return d
}
// Deployment sets the deployment to use for this operation.
func (d *Delete) Deployment(deployment driver.Deployment) *Delete {
if d == nil {
d = new(Delete)
}
d.deployment = deployment
return d
}
// ServerSelector sets the selector used to retrieve a server.
func (d *Delete) ServerSelector(selector description.ServerSelector) *Delete {
if d == nil {
d = new(Delete)
}
d.selector = selector
return d
}
// WriteConcern sets the write concern for this operation.
func (d *Delete) WriteConcern(writeConcern *writeconcern.WriteConcern) *Delete {
if d == nil {
d = new(Delete)
}
d.writeConcern = writeConcern
return d
}
// Retry enables retryable mode for this operation. Retries are handled automatically in driver.Operation.Execute based
// on how the operation is set.
func (d *Delete) Retry(retry driver.RetryMode) *Delete {
if d == nil {
d = new(Delete)
}
d.retry = &retry
return d
}
// Hint is a flag to indicate that the update document contains a hint. Hint is only supported by
// servers >= 4.4. Older servers will report an error for using the hint option.
func (d *Delete) Hint(hint bool) *Delete {
if d == nil {
d = new(Delete)
}
d.hint = &hint
return d
}
// ServerAPI sets the server API version for this operation.
func (d *Delete) ServerAPI(serverAPI *driver.ServerAPIOptions) *Delete {
if d == nil {
d = new(Delete)
}
d.serverAPI = serverAPI
return d
}
// Let specifies the let document to use. This option is only valid for server versions 5.0 and above.
func (d *Delete) Let(let bsoncore.Document) *Delete {
if d == nil {
d = new(Delete)
}
d.let = let
return d
}
// Timeout sets the timeout for this operation.
func (d *Delete) Timeout(timeout *time.Duration) *Delete {
if d == nil {
d = new(Delete)
}
d.timeout = timeout
return d
}
// Logger sets the logger for this operation.
func (d *Delete) Logger(logger *logger.Logger) *Delete {
if d == nil {
d = new(Delete)
}
d.logger = logger
return d
}
// Authenticator sets the authenticator to use for this operation.
func (d *Delete) Authenticator(authenticator driver.Authenticator) *Delete {
if d == nil {
d = new(Delete)
}
d.authenticator = authenticator
return d
}
// RawData sets the rawData to access timeseries data in the compressed format.
func (d *Delete) RawData(rawData bool) *Delete {
if d == nil {
d = new(Delete)
}
d.rawData = &rawData
return d
}

View File

@@ -0,0 +1,339 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// Distinct performs a distinct operation.
type Distinct struct {
authenticator driver.Authenticator
collation bsoncore.Document
key *string
query bsoncore.Document
session *session.Client
clock *session.ClusterClock
collection string
comment bsoncore.Value
hint bsoncore.Value
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
readConcern *readconcern.ReadConcern
readPreference *readpref.ReadPref
selector description.ServerSelector
retry *driver.RetryMode
result DistinctResult
serverAPI *driver.ServerAPIOptions
timeout *time.Duration
rawData *bool
}
// DistinctResult represents a distinct result returned by the server.
type DistinctResult struct {
// The distinct values for the field.
Values bsoncore.Value
}
func buildDistinctResult(response bsoncore.Document) (DistinctResult, error) {
elements, err := response.Elements()
if err != nil {
return DistinctResult{}, err
}
dr := DistinctResult{}
for _, element := range elements {
if element.Key() == "values" {
dr.Values = element.Value()
}
}
return dr, nil
}
// NewDistinct constructs and returns a new Distinct.
func NewDistinct(key string, query bsoncore.Document) *Distinct {
return &Distinct{
key: &key,
query: query,
}
}
// Result returns the result of executing this operation.
func (d *Distinct) Result() DistinctResult { return d.result }
func (d *Distinct) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
var err error
d.result, err = buildDistinctResult(resp)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (d *Distinct) Execute(ctx context.Context) error {
if d.deployment == nil {
return errors.New("the Distinct operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: d.command,
ProcessResponseFn: d.processResponse,
RetryMode: d.retry,
Type: driver.Read,
Client: d.session,
Clock: d.clock,
CommandMonitor: d.monitor,
Crypt: d.crypt,
Database: d.database,
Deployment: d.deployment,
ReadConcern: d.readConcern,
ReadPreference: d.readPreference,
Selector: d.selector,
ServerAPI: d.serverAPI,
Timeout: d.timeout,
Name: driverutil.DistinctOp,
Authenticator: d.authenticator,
}.Execute(ctx)
}
func (d *Distinct) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "distinct", d.collection)
if d.collation != nil {
if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) {
return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5")
}
dst = bsoncore.AppendDocumentElement(dst, "collation", d.collation)
}
if d.comment.Type != bsoncore.Type(0) {
dst = bsoncore.AppendValueElement(dst, "comment", d.comment)
}
if d.hint.Type != bsoncore.Type(0) {
dst = bsoncore.AppendValueElement(dst, "hint", d.hint)
}
if d.key != nil {
dst = bsoncore.AppendStringElement(dst, "key", *d.key)
}
if d.query != nil {
dst = bsoncore.AppendDocumentElement(dst, "query", d.query)
}
// Set rawData for 8.2+ servers.
if d.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
dst = bsoncore.AppendBooleanElement(dst, "rawData", *d.rawData)
}
return dst, nil
}
// Collation specifies a collation to be used.
func (d *Distinct) Collation(collation bsoncore.Document) *Distinct {
if d == nil {
d = new(Distinct)
}
d.collation = collation
return d
}
// Key specifies which field to return distinct values for.
func (d *Distinct) Key(key string) *Distinct {
if d == nil {
d = new(Distinct)
}
d.key = &key
return d
}
// Query specifies which documents to return distinct values from.
func (d *Distinct) Query(query bsoncore.Document) *Distinct {
if d == nil {
d = new(Distinct)
}
d.query = query
return d
}
// Session sets the session for this operation.
func (d *Distinct) Session(session *session.Client) *Distinct {
if d == nil {
d = new(Distinct)
}
d.session = session
return d
}
// ClusterClock sets the cluster clock for this operation.
func (d *Distinct) ClusterClock(clock *session.ClusterClock) *Distinct {
if d == nil {
d = new(Distinct)
}
d.clock = clock
return d
}
// Collection sets the collection that this command will run against.
func (d *Distinct) Collection(collection string) *Distinct {
if d == nil {
d = new(Distinct)
}
d.collection = collection
return d
}
// Comment sets a value to help trace an operation.
func (d *Distinct) Comment(comment bsoncore.Value) *Distinct {
if d == nil {
d = new(Distinct)
}
d.comment = comment
return d
}
// Hint sets a value to help trace an operation.
func (d *Distinct) Hint(hint bsoncore.Value) *Distinct {
if d == nil {
d = new(Distinct)
}
d.hint = hint
return d
}
// CommandMonitor sets the monitor to use for APM events.
func (d *Distinct) CommandMonitor(monitor *event.CommandMonitor) *Distinct {
if d == nil {
d = new(Distinct)
}
d.monitor = monitor
return d
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (d *Distinct) Crypt(crypt driver.Crypt) *Distinct {
if d == nil {
d = new(Distinct)
}
d.crypt = crypt
return d
}
// Database sets the database to run this operation against.
func (d *Distinct) Database(database string) *Distinct {
if d == nil {
d = new(Distinct)
}
d.database = database
return d
}
// Deployment sets the deployment to use for this operation.
func (d *Distinct) Deployment(deployment driver.Deployment) *Distinct {
if d == nil {
d = new(Distinct)
}
d.deployment = deployment
return d
}
// ReadConcern specifies the read concern for this operation.
func (d *Distinct) ReadConcern(readConcern *readconcern.ReadConcern) *Distinct {
if d == nil {
d = new(Distinct)
}
d.readConcern = readConcern
return d
}
// ReadPreference set the read preference used with this operation.
func (d *Distinct) ReadPreference(readPreference *readpref.ReadPref) *Distinct {
if d == nil {
d = new(Distinct)
}
d.readPreference = readPreference
return d
}
// ServerSelector sets the selector used to retrieve a server.
func (d *Distinct) ServerSelector(selector description.ServerSelector) *Distinct {
if d == nil {
d = new(Distinct)
}
d.selector = selector
return d
}
// Retry enables retryable mode for this operation. Retries are handled automatically in driver.Operation.Execute based
// on how the operation is set.
func (d *Distinct) Retry(retry driver.RetryMode) *Distinct {
if d == nil {
d = new(Distinct)
}
d.retry = &retry
return d
}
// ServerAPI sets the server API version for this operation.
func (d *Distinct) ServerAPI(serverAPI *driver.ServerAPIOptions) *Distinct {
if d == nil {
d = new(Distinct)
}
d.serverAPI = serverAPI
return d
}
// Timeout sets the timeout for this operation.
func (d *Distinct) Timeout(timeout *time.Duration) *Distinct {
if d == nil {
d = new(Distinct)
}
d.timeout = timeout
return d
}
// Authenticator sets the authenticator to use for this operation.
func (d *Distinct) Authenticator(authenticator driver.Authenticator) *Distinct {
if d == nil {
d = new(Distinct)
}
d.authenticator = authenticator
return d
}
// RawData sets the rawData to access timeseries data in the compressed format.
func (d *Distinct) RawData(rawData bool) *Distinct {
if d == nil {
d = new(Distinct)
}
d.rawData = &rawData
return d
}

View File

@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package operation is intended for internal use only. It is made available to
// facilitate use cases that require access to internal MongoDB driver
// functionality and state. The API of this package is not stable and there is
// no backward compatibility guarantee.
//
// WARNING: THIS PACKAGE IS EXPERIMENTAL AND MAY BE MODIFIED OR REMOVED WITHOUT
// NOTICE! USE WITH EXTREME CAUTION!
package operation

View File

@@ -0,0 +1,235 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"fmt"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// DropCollection performs a drop operation.
type DropCollection struct {
authenticator driver.Authenticator
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
result DropCollectionResult
serverAPI *driver.ServerAPIOptions
timeout *time.Duration
}
// DropCollectionResult represents a dropCollection result returned by the server.
type DropCollectionResult struct {
// The number of indexes in the dropped collection.
NIndexesWas int32
// The namespace of the dropped collection.
Ns string
}
func buildDropCollectionResult(response bsoncore.Document) (DropCollectionResult, error) {
elements, err := response.Elements()
if err != nil {
return DropCollectionResult{}, err
}
dcr := DropCollectionResult{}
for _, element := range elements {
switch element.Key() {
case "nIndexesWas":
var ok bool
dcr.NIndexesWas, ok = element.Value().AsInt32OK()
if !ok {
return dcr, fmt.Errorf("response field 'nIndexesWas' is type int32, but received BSON type %s", element.Value().Type)
}
case "ns":
var ok bool
dcr.Ns, ok = element.Value().StringValueOK()
if !ok {
return dcr, fmt.Errorf("response field 'ns' is type string, but received BSON type %s", element.Value().Type)
}
}
}
return dcr, nil
}
// NewDropCollection constructs and returns a new DropCollection.
func NewDropCollection() *DropCollection {
return &DropCollection{}
}
// Result returns the result of executing this operation.
func (dc *DropCollection) Result() DropCollectionResult { return dc.result }
func (dc *DropCollection) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
var err error
dc.result, err = buildDropCollectionResult(resp)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (dc *DropCollection) Execute(ctx context.Context) error {
if dc.deployment == nil {
return errors.New("the DropCollection operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: dc.command,
ProcessResponseFn: dc.processResponse,
Client: dc.session,
Clock: dc.clock,
CommandMonitor: dc.monitor,
Crypt: dc.crypt,
Database: dc.database,
Deployment: dc.deployment,
Selector: dc.selector,
WriteConcern: dc.writeConcern,
ServerAPI: dc.serverAPI,
Timeout: dc.timeout,
Name: driverutil.DropOp,
Authenticator: dc.authenticator,
}.Execute(ctx)
}
func (dc *DropCollection) command(dst []byte, _ description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "drop", dc.collection)
return dst, nil
}
// Session sets the session for this operation.
func (dc *DropCollection) Session(session *session.Client) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.session = session
return dc
}
// ClusterClock sets the cluster clock for this operation.
func (dc *DropCollection) ClusterClock(clock *session.ClusterClock) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.clock = clock
return dc
}
// Collection sets the collection that this command will run against.
func (dc *DropCollection) Collection(collection string) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.collection = collection
return dc
}
// CommandMonitor sets the monitor to use for APM events.
func (dc *DropCollection) CommandMonitor(monitor *event.CommandMonitor) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.monitor = monitor
return dc
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (dc *DropCollection) Crypt(crypt driver.Crypt) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.crypt = crypt
return dc
}
// Database sets the database to run this operation against.
func (dc *DropCollection) Database(database string) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.database = database
return dc
}
// Deployment sets the deployment to use for this operation.
func (dc *DropCollection) Deployment(deployment driver.Deployment) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.deployment = deployment
return dc
}
// ServerSelector sets the selector used to retrieve a server.
func (dc *DropCollection) ServerSelector(selector description.ServerSelector) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.selector = selector
return dc
}
// WriteConcern sets the write concern for this operation.
func (dc *DropCollection) WriteConcern(writeConcern *writeconcern.WriteConcern) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.writeConcern = writeConcern
return dc
}
// ServerAPI sets the server API version for this operation.
func (dc *DropCollection) ServerAPI(serverAPI *driver.ServerAPIOptions) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.serverAPI = serverAPI
return dc
}
// Timeout sets the timeout for this operation.
func (dc *DropCollection) Timeout(timeout *time.Duration) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.timeout = timeout
return dc
}
// Authenticator sets the authenticator to use for this operation.
func (dc *DropCollection) Authenticator(authenticator driver.Authenticator) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.authenticator = authenticator
return dc
}

View File

@@ -0,0 +1,166 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// DropDatabase performs a dropDatabase operation
type DropDatabase struct {
authenticator driver.Authenticator
session *session.Client
clock *session.ClusterClock
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
serverAPI *driver.ServerAPIOptions
}
// NewDropDatabase constructs and returns a new DropDatabase.
func NewDropDatabase() *DropDatabase {
return &DropDatabase{}
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (dd *DropDatabase) Execute(ctx context.Context) error {
if dd.deployment == nil {
return errors.New("the DropDatabase operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: dd.command,
Client: dd.session,
Clock: dd.clock,
CommandMonitor: dd.monitor,
Crypt: dd.crypt,
Database: dd.database,
Deployment: dd.deployment,
Selector: dd.selector,
WriteConcern: dd.writeConcern,
ServerAPI: dd.serverAPI,
Name: driverutil.DropDatabaseOp,
Authenticator: dd.authenticator,
}.Execute(ctx)
}
func (dd *DropDatabase) command(dst []byte, _ description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "dropDatabase", 1)
return dst, nil
}
// Session sets the session for this operation.
func (dd *DropDatabase) Session(session *session.Client) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.session = session
return dd
}
// ClusterClock sets the cluster clock for this operation.
func (dd *DropDatabase) ClusterClock(clock *session.ClusterClock) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.clock = clock
return dd
}
// CommandMonitor sets the monitor to use for APM events.
func (dd *DropDatabase) CommandMonitor(monitor *event.CommandMonitor) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.monitor = monitor
return dd
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (dd *DropDatabase) Crypt(crypt driver.Crypt) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.crypt = crypt
return dd
}
// Database sets the database to run this operation against.
func (dd *DropDatabase) Database(database string) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.database = database
return dd
}
// Deployment sets the deployment to use for this operation.
func (dd *DropDatabase) Deployment(deployment driver.Deployment) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.deployment = deployment
return dd
}
// ServerSelector sets the selector used to retrieve a server.
func (dd *DropDatabase) ServerSelector(selector description.ServerSelector) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.selector = selector
return dd
}
// WriteConcern sets the write concern for this operation.
func (dd *DropDatabase) WriteConcern(writeConcern *writeconcern.WriteConcern) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.writeConcern = writeConcern
return dd
}
// ServerAPI sets the server API version for this operation.
func (dd *DropDatabase) ServerAPI(serverAPI *driver.ServerAPIOptions) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.serverAPI = serverAPI
return dd
}
// Authenticator sets the authenticator to use for this operation.
func (dd *DropDatabase) Authenticator(authenticator driver.Authenticator) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.authenticator = authenticator
return dd
}

View File

@@ -0,0 +1,264 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"fmt"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// DropIndexes performs an dropIndexes operation.
type DropIndexes struct {
authenticator driver.Authenticator
index any
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
result DropIndexesResult
serverAPI *driver.ServerAPIOptions
timeout *time.Duration
rawData *bool
}
// DropIndexesResult represents a dropIndexes result returned by the server.
type DropIndexesResult struct {
// Number of indexes that existed before the drop was executed.
NIndexesWas int32
}
func buildDropIndexesResult(response bsoncore.Document) (DropIndexesResult, error) {
elements, err := response.Elements()
if err != nil {
return DropIndexesResult{}, err
}
dir := DropIndexesResult{}
for _, element := range elements {
if element.Key() == "nIndexesWas" {
var ok bool
dir.NIndexesWas, ok = element.Value().AsInt32OK()
if !ok {
return dir, fmt.Errorf("response field 'nIndexesWas' is type int32, but received BSON type %s", element.Value().Type)
}
}
}
return dir, nil
}
// NewDropIndexes constructs and returns a new DropIndexes.
func NewDropIndexes(index any) *DropIndexes {
return &DropIndexes{
index: index,
}
}
// Result returns the result of executing this operation.
func (di *DropIndexes) Result() DropIndexesResult { return di.result }
func (di *DropIndexes) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
var err error
di.result, err = buildDropIndexesResult(resp)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (di *DropIndexes) Execute(ctx context.Context) error {
if di.deployment == nil {
return errors.New("the DropIndexes operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: di.command,
ProcessResponseFn: di.processResponse,
Client: di.session,
Clock: di.clock,
CommandMonitor: di.monitor,
Crypt: di.crypt,
Database: di.database,
Deployment: di.deployment,
Selector: di.selector,
WriteConcern: di.writeConcern,
ServerAPI: di.serverAPI,
Timeout: di.timeout,
Name: driverutil.DropIndexesOp,
Authenticator: di.authenticator,
}.Execute(ctx)
}
func (di *DropIndexes) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "dropIndexes", di.collection)
switch t := di.index.(type) {
case string:
dst = bsoncore.AppendStringElement(dst, "index", t)
case bsoncore.Document:
if di.index != nil {
dst = bsoncore.AppendDocumentElement(dst, "index", t)
}
}
// Set rawData for 8.2+ servers.
if di.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
dst = bsoncore.AppendBooleanElement(dst, "rawData", *di.rawData)
}
return dst, nil
}
// Index specifies the name of the index to drop. If '*' is specified, all indexes will be dropped.
func (di *DropIndexes) Index(index any) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.index = index
return di
}
// Session sets the session for this operation.
func (di *DropIndexes) Session(session *session.Client) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.session = session
return di
}
// ClusterClock sets the cluster clock for this operation.
func (di *DropIndexes) ClusterClock(clock *session.ClusterClock) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.clock = clock
return di
}
// Collection sets the collection that this command will run against.
func (di *DropIndexes) Collection(collection string) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.collection = collection
return di
}
// CommandMonitor sets the monitor to use for APM events.
func (di *DropIndexes) CommandMonitor(monitor *event.CommandMonitor) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.monitor = monitor
return di
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (di *DropIndexes) Crypt(crypt driver.Crypt) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.crypt = crypt
return di
}
// Database sets the database to run this operation against.
func (di *DropIndexes) Database(database string) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.database = database
return di
}
// Deployment sets the deployment to use for this operation.
func (di *DropIndexes) Deployment(deployment driver.Deployment) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.deployment = deployment
return di
}
// ServerSelector sets the selector used to retrieve a server.
func (di *DropIndexes) ServerSelector(selector description.ServerSelector) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.selector = selector
return di
}
// WriteConcern sets the write concern for this operation.
func (di *DropIndexes) WriteConcern(writeConcern *writeconcern.WriteConcern) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.writeConcern = writeConcern
return di
}
// ServerAPI sets the server API version for this operation.
func (di *DropIndexes) ServerAPI(serverAPI *driver.ServerAPIOptions) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.serverAPI = serverAPI
return di
}
// Timeout sets the timeout for this operation.
func (di *DropIndexes) Timeout(timeout *time.Duration) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.timeout = timeout
return di
}
// Authenticator sets the authenticator to use for this operation.
func (di *DropIndexes) Authenticator(authenticator driver.Authenticator) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.authenticator = authenticator
return di
}
// RawData sets the rawData to access timeseries data in the compressed format.
func (di *DropIndexes) RawData(rawData bool) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.rawData = &rawData
return di
}

View File

@@ -0,0 +1,224 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"fmt"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// DropSearchIndex performs an dropSearchIndex operation.
type DropSearchIndex struct {
authenticator driver.Authenticator
index string
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
selector description.ServerSelector
result DropSearchIndexResult
serverAPI *driver.ServerAPIOptions
timeout *time.Duration
}
// DropSearchIndexResult represents a dropSearchIndex result returned by the server.
type DropSearchIndexResult struct {
Ok int32
}
func buildDropSearchIndexResult(response bsoncore.Document) (DropSearchIndexResult, error) {
elements, err := response.Elements()
if err != nil {
return DropSearchIndexResult{}, err
}
dsir := DropSearchIndexResult{}
for _, element := range elements {
if element.Key() == "ok" {
var ok bool
dsir.Ok, ok = element.Value().AsInt32OK()
if !ok {
return dsir, fmt.Errorf("response field 'ok' is type int32, but received BSON type %s", element.Value().Type)
}
}
}
return dsir, nil
}
// NewDropSearchIndex constructs and returns a new DropSearchIndex.
func NewDropSearchIndex(index string) *DropSearchIndex {
return &DropSearchIndex{
index: index,
}
}
// Result returns the result of executing this operation.
func (dsi *DropSearchIndex) Result() DropSearchIndexResult { return dsi.result }
func (dsi *DropSearchIndex) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
var err error
dsi.result, err = buildDropSearchIndexResult(resp)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (dsi *DropSearchIndex) Execute(ctx context.Context) error {
if dsi.deployment == nil {
return errors.New("the DropSearchIndex operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: dsi.command,
ProcessResponseFn: dsi.processResponse,
Client: dsi.session,
Clock: dsi.clock,
CommandMonitor: dsi.monitor,
Crypt: dsi.crypt,
Database: dsi.database,
Deployment: dsi.deployment,
Selector: dsi.selector,
ServerAPI: dsi.serverAPI,
Timeout: dsi.timeout,
Authenticator: dsi.authenticator,
}.Execute(ctx)
}
func (dsi *DropSearchIndex) command(dst []byte, _ description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "dropSearchIndex", dsi.collection)
dst = bsoncore.AppendStringElement(dst, "name", dsi.index)
return dst, nil
}
// Index specifies the name of the index to drop. If '*' is specified, all indexes will be dropped.
func (dsi *DropSearchIndex) Index(index string) *DropSearchIndex {
if dsi == nil {
dsi = new(DropSearchIndex)
}
dsi.index = index
return dsi
}
// Session sets the session for this operation.
func (dsi *DropSearchIndex) Session(session *session.Client) *DropSearchIndex {
if dsi == nil {
dsi = new(DropSearchIndex)
}
dsi.session = session
return dsi
}
// ClusterClock sets the cluster clock for this operation.
func (dsi *DropSearchIndex) ClusterClock(clock *session.ClusterClock) *DropSearchIndex {
if dsi == nil {
dsi = new(DropSearchIndex)
}
dsi.clock = clock
return dsi
}
// Collection sets the collection that this command will run against.
func (dsi *DropSearchIndex) Collection(collection string) *DropSearchIndex {
if dsi == nil {
dsi = new(DropSearchIndex)
}
dsi.collection = collection
return dsi
}
// CommandMonitor sets the monitor to use for APM events.
func (dsi *DropSearchIndex) CommandMonitor(monitor *event.CommandMonitor) *DropSearchIndex {
if dsi == nil {
dsi = new(DropSearchIndex)
}
dsi.monitor = monitor
return dsi
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (dsi *DropSearchIndex) Crypt(crypt driver.Crypt) *DropSearchIndex {
if dsi == nil {
dsi = new(DropSearchIndex)
}
dsi.crypt = crypt
return dsi
}
// Database sets the database to run this operation against.
func (dsi *DropSearchIndex) Database(database string) *DropSearchIndex {
if dsi == nil {
dsi = new(DropSearchIndex)
}
dsi.database = database
return dsi
}
// Deployment sets the deployment to use for this operation.
func (dsi *DropSearchIndex) Deployment(deployment driver.Deployment) *DropSearchIndex {
if dsi == nil {
dsi = new(DropSearchIndex)
}
dsi.deployment = deployment
return dsi
}
// ServerSelector sets the selector used to retrieve a server.
func (dsi *DropSearchIndex) ServerSelector(selector description.ServerSelector) *DropSearchIndex {
if dsi == nil {
dsi = new(DropSearchIndex)
}
dsi.selector = selector
return dsi
}
// ServerAPI sets the server API version for this operation.
func (dsi *DropSearchIndex) ServerAPI(serverAPI *driver.ServerAPIOptions) *DropSearchIndex {
if dsi == nil {
dsi = new(DropSearchIndex)
}
dsi.serverAPI = serverAPI
return dsi
}
// Timeout sets the timeout for this operation.
func (dsi *DropSearchIndex) Timeout(timeout *time.Duration) *DropSearchIndex {
if dsi == nil {
dsi = new(DropSearchIndex)
}
dsi.timeout = timeout
return dsi
}
// Authenticator sets the authenticator to use for this operation.
func (dsi *DropSearchIndex) Authenticator(authenticator driver.Authenticator) *DropSearchIndex {
if dsi == nil {
dsi = new(DropSearchIndex)
}
dsi.authenticator = authenticator
return dsi
}

View File

@@ -0,0 +1,173 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// EndSessions performs an endSessions operation.
type EndSessions struct {
authenticator driver.Authenticator
sessionIDs bsoncore.Document
session *session.Client
clock *session.ClusterClock
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
selector description.ServerSelector
serverAPI *driver.ServerAPIOptions
}
// NewEndSessions constructs and returns a new EndSessions.
func NewEndSessions(sessionIDs bsoncore.Document) *EndSessions {
return &EndSessions{
sessionIDs: sessionIDs,
}
}
func (es *EndSessions) processResponse(context.Context, bsoncore.Document, driver.ResponseInfo) error {
return nil
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (es *EndSessions) Execute(ctx context.Context) error {
if es.deployment == nil {
return errors.New("the EndSessions operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: es.command,
ProcessResponseFn: es.processResponse,
Client: es.session,
Clock: es.clock,
CommandMonitor: es.monitor,
Crypt: es.crypt,
Database: es.database,
Deployment: es.deployment,
Selector: es.selector,
ServerAPI: es.serverAPI,
Name: driverutil.EndSessionsOp,
Authenticator: es.authenticator,
}.Execute(ctx)
}
func (es *EndSessions) command(dst []byte, _ description.SelectedServer) ([]byte, error) {
if es.sessionIDs != nil {
dst = bsoncore.AppendArrayElement(dst, "endSessions", es.sessionIDs)
}
return dst, nil
}
// SessionIDs specifies the sessions to be expired.
func (es *EndSessions) SessionIDs(sessionIDs bsoncore.Document) *EndSessions {
if es == nil {
es = new(EndSessions)
}
es.sessionIDs = sessionIDs
return es
}
// Session sets the session for this operation.
func (es *EndSessions) Session(session *session.Client) *EndSessions {
if es == nil {
es = new(EndSessions)
}
es.session = session
return es
}
// ClusterClock sets the cluster clock for this operation.
func (es *EndSessions) ClusterClock(clock *session.ClusterClock) *EndSessions {
if es == nil {
es = new(EndSessions)
}
es.clock = clock
return es
}
// CommandMonitor sets the monitor to use for APM events.
func (es *EndSessions) CommandMonitor(monitor *event.CommandMonitor) *EndSessions {
if es == nil {
es = new(EndSessions)
}
es.monitor = monitor
return es
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (es *EndSessions) Crypt(crypt driver.Crypt) *EndSessions {
if es == nil {
es = new(EndSessions)
}
es.crypt = crypt
return es
}
// Database sets the database to run this operation against.
func (es *EndSessions) Database(database string) *EndSessions {
if es == nil {
es = new(EndSessions)
}
es.database = database
return es
}
// Deployment sets the deployment to use for this operation.
func (es *EndSessions) Deployment(deployment driver.Deployment) *EndSessions {
if es == nil {
es = new(EndSessions)
}
es.deployment = deployment
return es
}
// ServerSelector sets the selector used to retrieve a server.
func (es *EndSessions) ServerSelector(selector description.ServerSelector) *EndSessions {
if es == nil {
es = new(EndSessions)
}
es.selector = selector
return es
}
// ServerAPI sets the server API version for this operation.
func (es *EndSessions) ServerAPI(serverAPI *driver.ServerAPIOptions) *EndSessions {
if es == nil {
es = new(EndSessions)
}
es.serverAPI = serverAPI
return es
}
// Authenticator sets the authenticator to use for this operation.
func (es *EndSessions) Authenticator(authenticator driver.Authenticator) *EndSessions {
if es == nil {
es = new(EndSessions)
}
es.authenticator = authenticator
return es
}

View File

@@ -0,0 +1,11 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import "errors"
var errUnacknowledgedHint = errors.New("the 'hint' command parameter cannot be used with unacknowledged writes")

View File

@@ -0,0 +1,592 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/internal/logger"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// Find performs a find operation.
type Find struct {
authenticator driver.Authenticator
allowDiskUse *bool
allowPartialResults *bool
awaitData *bool
batchSize *int32
collation bsoncore.Document
comment bsoncore.Value
filter bsoncore.Document
hint bsoncore.Value
let bsoncore.Document
limit *int64
max bsoncore.Document
min bsoncore.Document
noCursorTimeout *bool
oplogReplay *bool
projection bsoncore.Document
returnKey *bool
showRecordID *bool
singleBatch *bool
skip *int64
snapshot *bool
sort bsoncore.Document
tailable *bool
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
readConcern *readconcern.ReadConcern
readPreference *readpref.ReadPref
selector description.ServerSelector
retry *driver.RetryMode
result driver.CursorResponse
serverAPI *driver.ServerAPIOptions
timeout *time.Duration
rawData *bool
logger *logger.Logger
omitMaxTimeMS bool
}
// NewFind constructs and returns a new Find.
func NewFind(filter bsoncore.Document) *Find {
return &Find{
filter: filter,
}
}
// Result returns the result of executing this operation.
func (f *Find) Result(opts driver.CursorOptions) (*driver.BatchCursor, error) {
opts.ServerAPI = f.serverAPI
return driver.NewBatchCursor(f.result, f.session, f.clock, opts)
}
func (f *Find) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error {
curDoc, err := driver.ExtractCursorDocument(resp)
if err != nil {
return err
}
f.result, err = driver.NewCursorResponse(curDoc, info)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (f *Find) Execute(ctx context.Context) error {
if f.deployment == nil {
return errors.New("the Find operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: f.command,
ProcessResponseFn: f.processResponse,
RetryMode: f.retry,
Type: driver.Read,
Client: f.session,
Clock: f.clock,
CommandMonitor: f.monitor,
Crypt: f.crypt,
Database: f.database,
Deployment: f.deployment,
ReadConcern: f.readConcern,
ReadPreference: f.readPreference,
Selector: f.selector,
Legacy: driver.LegacyFind,
ServerAPI: f.serverAPI,
Timeout: f.timeout,
Logger: f.logger,
Name: driverutil.FindOp,
Authenticator: f.authenticator,
OmitMaxTimeMS: f.omitMaxTimeMS,
}.Execute(ctx)
}
func (f *Find) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "find", f.collection)
if f.allowDiskUse != nil {
if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 4) {
return nil, errors.New("the 'allowDiskUse' command parameter requires a minimum server wire version of 4")
}
dst = bsoncore.AppendBooleanElement(dst, "allowDiskUse", *f.allowDiskUse)
}
if f.allowPartialResults != nil {
dst = bsoncore.AppendBooleanElement(dst, "allowPartialResults", *f.allowPartialResults)
}
if f.awaitData != nil {
dst = bsoncore.AppendBooleanElement(dst, "awaitData", *f.awaitData)
}
if f.batchSize != nil {
dst = bsoncore.AppendInt32Element(dst, "batchSize", *f.batchSize)
}
if f.collation != nil {
if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) {
return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5")
}
dst = bsoncore.AppendDocumentElement(dst, "collation", f.collation)
}
if f.comment.Type != bsoncore.Type(0) {
dst = bsoncore.AppendValueElement(dst, "comment", f.comment)
}
if f.filter != nil {
dst = bsoncore.AppendDocumentElement(dst, "filter", f.filter)
}
if f.hint.Type != bsoncore.Type(0) {
dst = bsoncore.AppendValueElement(dst, "hint", f.hint)
}
if f.let != nil {
dst = bsoncore.AppendDocumentElement(dst, "let", f.let)
}
if f.limit != nil {
dst = bsoncore.AppendInt64Element(dst, "limit", *f.limit)
}
if f.max != nil {
dst = bsoncore.AppendDocumentElement(dst, "max", f.max)
}
if f.min != nil {
dst = bsoncore.AppendDocumentElement(dst, "min", f.min)
}
if f.noCursorTimeout != nil {
dst = bsoncore.AppendBooleanElement(dst, "noCursorTimeout", *f.noCursorTimeout)
}
if f.oplogReplay != nil {
dst = bsoncore.AppendBooleanElement(dst, "oplogReplay", *f.oplogReplay)
}
if f.projection != nil {
dst = bsoncore.AppendDocumentElement(dst, "projection", f.projection)
}
if f.returnKey != nil {
dst = bsoncore.AppendBooleanElement(dst, "returnKey", *f.returnKey)
}
if f.showRecordID != nil {
dst = bsoncore.AppendBooleanElement(dst, "showRecordId", *f.showRecordID)
}
if f.singleBatch != nil {
dst = bsoncore.AppendBooleanElement(dst, "singleBatch", *f.singleBatch)
}
if f.skip != nil {
dst = bsoncore.AppendInt64Element(dst, "skip", *f.skip)
}
if f.snapshot != nil {
dst = bsoncore.AppendBooleanElement(dst, "snapshot", *f.snapshot)
}
if f.sort != nil {
dst = bsoncore.AppendDocumentElement(dst, "sort", f.sort)
}
if f.tailable != nil {
dst = bsoncore.AppendBooleanElement(dst, "tailable", *f.tailable)
}
// Set rawData for 8.2+ servers.
if f.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
dst = bsoncore.AppendBooleanElement(dst, "rawData", *f.rawData)
}
return dst, nil
}
// AllowDiskUse when true allows temporary data to be written to disk during the find command."
func (f *Find) AllowDiskUse(allowDiskUse bool) *Find {
if f == nil {
f = new(Find)
}
f.allowDiskUse = &allowDiskUse
return f
}
// AllowPartialResults when true allows partial results to be returned if some shards are down.
func (f *Find) AllowPartialResults(allowPartialResults bool) *Find {
if f == nil {
f = new(Find)
}
f.allowPartialResults = &allowPartialResults
return f
}
// AwaitData when true makes a cursor block before returning when no data is available.
func (f *Find) AwaitData(awaitData bool) *Find {
if f == nil {
f = new(Find)
}
f.awaitData = &awaitData
return f
}
// BatchSize specifies the number of documents to return in every batch.
func (f *Find) BatchSize(batchSize int32) *Find {
if f == nil {
f = new(Find)
}
f.batchSize = &batchSize
return f
}
// Collation specifies a collation to be used.
func (f *Find) Collation(collation bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.collation = collation
return f
}
// Comment sets a value to help trace an operation.
func (f *Find) Comment(comment bsoncore.Value) *Find {
if f == nil {
f = new(Find)
}
f.comment = comment
return f
}
// Filter determines what results are returned from find.
func (f *Find) Filter(filter bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.filter = filter
return f
}
// Hint specifies the index to use.
func (f *Find) Hint(hint bsoncore.Value) *Find {
if f == nil {
f = new(Find)
}
f.hint = hint
return f
}
// Let specifies the let document to use. This option is only valid for server versions 5.0 and above.
func (f *Find) Let(let bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.let = let
return f
}
// Limit sets a limit on the number of documents to return.
func (f *Find) Limit(limit int64) *Find {
if f == nil {
f = new(Find)
}
f.limit = &limit
return f
}
// Max sets an exclusive upper bound for a specific index.
func (f *Find) Max(max bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.max = max
return f
}
// Min sets an inclusive lower bound for a specific index.
func (f *Find) Min(min bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.min = min
return f
}
// NoCursorTimeout when true prevents cursor from timing out after an inactivity period.
func (f *Find) NoCursorTimeout(noCursorTimeout bool) *Find {
if f == nil {
f = new(Find)
}
f.noCursorTimeout = &noCursorTimeout
return f
}
// OplogReplay when true replays a replica set's oplog.
func (f *Find) OplogReplay(oplogReplay bool) *Find {
if f == nil {
f = new(Find)
}
f.oplogReplay = &oplogReplay
return f
}
// Projection limits the fields returned for all documents.
func (f *Find) Projection(projection bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.projection = projection
return f
}
// ReturnKey when true returns index keys for all result documents.
func (f *Find) ReturnKey(returnKey bool) *Find {
if f == nil {
f = new(Find)
}
f.returnKey = &returnKey
return f
}
// ShowRecordID when true adds a $recordId field with the record identifier to returned documents.
func (f *Find) ShowRecordID(showRecordID bool) *Find {
if f == nil {
f = new(Find)
}
f.showRecordID = &showRecordID
return f
}
// SingleBatch specifies whether the results should be returned in a single batch.
func (f *Find) SingleBatch(singleBatch bool) *Find {
if f == nil {
f = new(Find)
}
f.singleBatch = &singleBatch
return f
}
// Skip specifies the number of documents to skip before returning.
func (f *Find) Skip(skip int64) *Find {
if f == nil {
f = new(Find)
}
f.skip = &skip
return f
}
// Snapshot prevents the cursor from returning a document more than once because of an intervening write operation.
func (f *Find) Snapshot(snapshot bool) *Find {
if f == nil {
f = new(Find)
}
f.snapshot = &snapshot
return f
}
// Sort specifies the order in which to return results.
func (f *Find) Sort(sort bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.sort = sort
return f
}
// Tailable keeps a cursor open and resumable after the last data has been retrieved.
func (f *Find) Tailable(tailable bool) *Find {
if f == nil {
f = new(Find)
}
f.tailable = &tailable
return f
}
// Session sets the session for this operation.
func (f *Find) Session(session *session.Client) *Find {
if f == nil {
f = new(Find)
}
f.session = session
return f
}
// ClusterClock sets the cluster clock for this operation.
func (f *Find) ClusterClock(clock *session.ClusterClock) *Find {
if f == nil {
f = new(Find)
}
f.clock = clock
return f
}
// Collection sets the collection that this command will run against.
func (f *Find) Collection(collection string) *Find {
if f == nil {
f = new(Find)
}
f.collection = collection
return f
}
// CommandMonitor sets the monitor to use for APM events.
func (f *Find) CommandMonitor(monitor *event.CommandMonitor) *Find {
if f == nil {
f = new(Find)
}
f.monitor = monitor
return f
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (f *Find) Crypt(crypt driver.Crypt) *Find {
if f == nil {
f = new(Find)
}
f.crypt = crypt
return f
}
// Database sets the database to run this operation against.
func (f *Find) Database(database string) *Find {
if f == nil {
f = new(Find)
}
f.database = database
return f
}
// Deployment sets the deployment to use for this operation.
func (f *Find) Deployment(deployment driver.Deployment) *Find {
if f == nil {
f = new(Find)
}
f.deployment = deployment
return f
}
// ReadConcern specifies the read concern for this operation.
func (f *Find) ReadConcern(readConcern *readconcern.ReadConcern) *Find {
if f == nil {
f = new(Find)
}
f.readConcern = readConcern
return f
}
// ReadPreference set the read preference used with this operation.
func (f *Find) ReadPreference(readPreference *readpref.ReadPref) *Find {
if f == nil {
f = new(Find)
}
f.readPreference = readPreference
return f
}
// ServerSelector sets the selector used to retrieve a server.
func (f *Find) ServerSelector(selector description.ServerSelector) *Find {
if f == nil {
f = new(Find)
}
f.selector = selector
return f
}
// Retry enables retryable mode for this operation. Retries are handled automatically in driver.Operation.Execute based
// on how the operation is set.
func (f *Find) Retry(retry driver.RetryMode) *Find {
if f == nil {
f = new(Find)
}
f.retry = &retry
return f
}
// ServerAPI sets the server API version for this operation.
func (f *Find) ServerAPI(serverAPI *driver.ServerAPIOptions) *Find {
if f == nil {
f = new(Find)
}
f.serverAPI = serverAPI
return f
}
// Timeout sets the timeout for this operation.
func (f *Find) Timeout(timeout *time.Duration) *Find {
if f == nil {
f = new(Find)
}
f.timeout = timeout
return f
}
// Logger sets the logger for this operation.
func (f *Find) Logger(logger *logger.Logger) *Find {
if f == nil {
f = new(Find)
}
f.logger = logger
return f
}
// Authenticator sets the authenticator to use for this operation.
func (f *Find) Authenticator(authenticator driver.Authenticator) *Find {
if f == nil {
f = new(Find)
}
f.authenticator = authenticator
return f
}
// RawData sets the rawData to access timeseries data in the compressed format.
func (f *Find) RawData(rawData bool) *Find {
if f == nil {
f = new(Find)
}
f.rawData = &rawData
return f
}
// OmitMaxTimeMS omits the automatically-calculated "maxTimeMS" from the
// command.
func (f *Find) OmitMaxTimeMS(omit bool) *Find {
if f == nil {
f = new(Find)
}
f.omitMaxTimeMS = omit
return f
}

View File

@@ -0,0 +1,502 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"fmt"
"time"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// FindAndModify performs a findAndModify operation.
type FindAndModify struct {
authenticator driver.Authenticator
arrayFilters bsoncore.Array
bypassDocumentValidation *bool
collation bsoncore.Document
comment bsoncore.Value
fields bsoncore.Document
newDocument *bool
query bsoncore.Document
remove *bool
sort bsoncore.Document
update bsoncore.Value
upsert *bool
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
retry *driver.RetryMode
crypt driver.Crypt
hint bsoncore.Value
serverAPI *driver.ServerAPIOptions
let bsoncore.Document
timeout *time.Duration
rawData *bool
additionalCmd bson.D
result FindAndModifyResult
}
// LastErrorObject represents information about updates and upserts returned by the server.
type LastErrorObject struct {
// True if an update modified an existing document
UpdatedExisting bool
// Object ID of the upserted document.
Upserted any
}
// FindAndModifyResult represents a findAndModify result returned by the server.
type FindAndModifyResult struct {
// Either the old or modified document, depending on the value of the new parameter.
Value bsoncore.Document
// Contains information about updates and upserts.
LastErrorObject LastErrorObject
}
func buildFindAndModifyResult(response bsoncore.Document) (FindAndModifyResult, error) {
elements, err := response.Elements()
if err != nil {
return FindAndModifyResult{}, err
}
famr := FindAndModifyResult{}
for _, element := range elements {
switch element.Key() {
case "value":
var ok bool
famr.Value, ok = element.Value().DocumentOK()
// The 'value' field returned by a FindAndModify can be null in the case that no document was found.
if element.Value().Type != bsoncore.TypeNull && !ok {
return famr, fmt.Errorf("response field 'value' is type document or null, but received BSON type %s", element.Value().Type)
}
case "lastErrorObject":
valDoc, ok := element.Value().DocumentOK()
if !ok {
return famr, fmt.Errorf("response field 'lastErrorObject' is type document, but received BSON type %s", element.Value().Type)
}
var leo LastErrorObject
if err = bson.Unmarshal(valDoc, &leo); err != nil {
return famr, err
}
famr.LastErrorObject = leo
}
}
return famr, nil
}
// NewFindAndModify constructs and returns a new FindAndModify.
func NewFindAndModify(query bsoncore.Document) *FindAndModify {
return &FindAndModify{
query: query,
}
}
// Result returns the result of executing this operation.
func (fam *FindAndModify) Result() FindAndModifyResult { return fam.result }
func (fam *FindAndModify) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
var err error
fam.result, err = buildFindAndModifyResult(resp)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (fam *FindAndModify) Execute(ctx context.Context) error {
if fam.deployment == nil {
return errors.New("the FindAndModify operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: fam.command,
ProcessResponseFn: fam.processResponse,
RetryMode: fam.retry,
Type: driver.Write,
Client: fam.session,
Clock: fam.clock,
CommandMonitor: fam.monitor,
Database: fam.database,
Deployment: fam.deployment,
Selector: fam.selector,
WriteConcern: fam.writeConcern,
Crypt: fam.crypt,
ServerAPI: fam.serverAPI,
Timeout: fam.timeout,
Name: driverutil.FindAndModifyOp,
Authenticator: fam.authenticator,
}.Execute(ctx)
}
func (fam *FindAndModify) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "findAndModify", fam.collection)
if fam.arrayFilters != nil {
if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 6) {
return nil, errors.New("the 'arrayFilters' command parameter requires a minimum server wire version of 6")
}
dst = bsoncore.AppendArrayElement(dst, "arrayFilters", fam.arrayFilters)
}
if fam.bypassDocumentValidation != nil {
dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *fam.bypassDocumentValidation)
}
if fam.collation != nil {
if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) {
return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5")
}
dst = bsoncore.AppendDocumentElement(dst, "collation", fam.collation)
}
if fam.comment.Type != bsoncore.Type(0) {
dst = bsoncore.AppendValueElement(dst, "comment", fam.comment)
}
if fam.fields != nil {
dst = bsoncore.AppendDocumentElement(dst, "fields", fam.fields)
}
if fam.newDocument != nil {
dst = bsoncore.AppendBooleanElement(dst, "new", *fam.newDocument)
}
if fam.query != nil {
dst = bsoncore.AppendDocumentElement(dst, "query", fam.query)
}
if fam.remove != nil {
dst = bsoncore.AppendBooleanElement(dst, "remove", *fam.remove)
}
if fam.sort != nil {
dst = bsoncore.AppendDocumentElement(dst, "sort", fam.sort)
}
if fam.update.Data != nil {
dst = bsoncore.AppendValueElement(dst, "update", fam.update)
}
if fam.upsert != nil {
dst = bsoncore.AppendBooleanElement(dst, "upsert", *fam.upsert)
}
if fam.hint.Type != bsoncore.Type(0) {
if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 8) {
return nil, errors.New("the 'hint' command parameter requires a minimum server wire version of 8")
}
if !fam.writeConcern.Acknowledged() {
return nil, errUnacknowledgedHint
}
dst = bsoncore.AppendValueElement(dst, "hint", fam.hint)
}
if fam.let != nil {
dst = bsoncore.AppendDocumentElement(dst, "let", fam.let)
}
// Set rawData for 8.2+ servers.
if fam.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
dst = bsoncore.AppendBooleanElement(dst, "rawData", *fam.rawData)
}
if len(fam.additionalCmd) > 0 {
doc, err := bson.Marshal(fam.additionalCmd)
if err != nil {
return nil, err
}
dst = append(dst, doc[4:len(doc)-1]...)
}
return dst, nil
}
// ArrayFilters specifies an array of filter documents that determines which array elements to modify for an update operation on an array field.
func (fam *FindAndModify) ArrayFilters(arrayFilters bsoncore.Array) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.arrayFilters = arrayFilters
return fam
}
// BypassDocumentValidation specifies if document validation can be skipped when executing the operation.
func (fam *FindAndModify) BypassDocumentValidation(bypassDocumentValidation bool) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.bypassDocumentValidation = &bypassDocumentValidation
return fam
}
// Collation specifies a collation to be used.
func (fam *FindAndModify) Collation(collation bsoncore.Document) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.collation = collation
return fam
}
// Comment sets a value to help trace an operation.
func (fam *FindAndModify) Comment(comment bsoncore.Value) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.comment = comment
return fam
}
// Fields specifies a subset of fields to return.
func (fam *FindAndModify) Fields(fields bsoncore.Document) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.fields = fields
return fam
}
// NewDocument specifies whether to return the modified document or the original. Defaults to false (return original).
func (fam *FindAndModify) NewDocument(newDocument bool) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.newDocument = &newDocument
return fam
}
// Query specifies the selection criteria for the modification.
func (fam *FindAndModify) Query(query bsoncore.Document) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.query = query
return fam
}
// Remove specifies that the matched document should be removed. Defaults to false.
func (fam *FindAndModify) Remove(remove bool) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.remove = &remove
return fam
}
// Sort determines which document the operation modifies if the query matches multiple documents.The first document matched by the sort order will be modified.
func (fam *FindAndModify) Sort(sort bsoncore.Document) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.sort = sort
return fam
}
// Update specifies the update document to perform on the matched document.
func (fam *FindAndModify) Update(update bsoncore.Value) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.update = update
return fam
}
// Upsert specifies whether or not to create a new document if no documents match the query when doing an update. Defaults to false.
func (fam *FindAndModify) Upsert(upsert bool) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.upsert = &upsert
return fam
}
// Session sets the session for this operation.
func (fam *FindAndModify) Session(session *session.Client) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.session = session
return fam
}
// ClusterClock sets the cluster clock for this operation.
func (fam *FindAndModify) ClusterClock(clock *session.ClusterClock) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.clock = clock
return fam
}
// Collection sets the collection that this command will run against.
func (fam *FindAndModify) Collection(collection string) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.collection = collection
return fam
}
// CommandMonitor sets the monitor to use for APM events.
func (fam *FindAndModify) CommandMonitor(monitor *event.CommandMonitor) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.monitor = monitor
return fam
}
// Database sets the database to run this operation against.
func (fam *FindAndModify) Database(database string) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.database = database
return fam
}
// Deployment sets the deployment to use for this operation.
func (fam *FindAndModify) Deployment(deployment driver.Deployment) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.deployment = deployment
return fam
}
// ServerSelector sets the selector used to retrieve a server.
func (fam *FindAndModify) ServerSelector(selector description.ServerSelector) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.selector = selector
return fam
}
// WriteConcern sets the write concern for this operation.
func (fam *FindAndModify) WriteConcern(writeConcern *writeconcern.WriteConcern) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.writeConcern = writeConcern
return fam
}
// Retry enables retryable writes for this operation. Retries are not handled automatically,
// instead a boolean is returned from Execute and SelectAndExecute that indicates if the
// operation can be retried. Retrying is handled by calling RetryExecute.
func (fam *FindAndModify) Retry(retry driver.RetryMode) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.retry = &retry
return fam
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (fam *FindAndModify) Crypt(crypt driver.Crypt) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.crypt = crypt
return fam
}
// Hint specifies the index to use.
func (fam *FindAndModify) Hint(hint bsoncore.Value) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.hint = hint
return fam
}
// ServerAPI sets the server API version for this operation.
func (fam *FindAndModify) ServerAPI(serverAPI *driver.ServerAPIOptions) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.serverAPI = serverAPI
return fam
}
// Let specifies the let document to use. This option is only valid for server versions 5.0 and above.
func (fam *FindAndModify) Let(let bsoncore.Document) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.let = let
return fam
}
// Timeout sets the timeout for this operation.
func (fam *FindAndModify) Timeout(timeout *time.Duration) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.timeout = timeout
return fam
}
// Authenticator sets the authenticator to use for this operation.
func (fam *FindAndModify) Authenticator(authenticator driver.Authenticator) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.authenticator = authenticator
return fam
}
// RawData sets the rawData to access timeseries data in the compressed format.
func (fam *FindAndModify) RawData(rawData bool) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.rawData = &rawData
return fam
}
// AdditionalCmd sets additional command fields to be attached.
func (fam *FindAndModify) AdditionalCmd(d bson.D) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.additionalCmd = d
return fam
}

View File

@@ -0,0 +1,718 @@
// Copyright (C) MongoDB, Inc. 2021-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"os"
"runtime"
"strconv"
"strings"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/bsonutil"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/internal/handshake"
"go.mongodb.org/mongo-driver/v2/mongo/address"
"go.mongodb.org/mongo-driver/v2/version"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// maxClientMetadataSize is the maximum size of the client metadata document
// that can be sent to the server. Note that the maximum document size on
// standalone and replica servers is 1024, but the maximum document size on
// sharded clusters is 512.
const maxClientMetadataSize = 512
const driverName = "mongo-go-driver"
// Hello is used to run the handshake operation.
type Hello struct {
authenticator driver.Authenticator
appname string
compressors []string
saslSupportedMechs string
d driver.Deployment
clock *session.ClusterClock
speculativeAuth bsoncore.Document
topologyVersion *description.TopologyVersion
maxAwaitTimeMS *int64
serverAPI *driver.ServerAPIOptions
loadBalanced bool
omitMaxTimeMS bool
// Fields provided by a library that wraps the Go Driver.
outerLibraryName string
outerLibraryVersion string
outerLibraryPlatform string
res bsoncore.Document
}
var _ driver.Handshaker = (*Hello)(nil)
// NewHello constructs a Hello.
func NewHello() *Hello { return &Hello{} }
// AppName sets the application name in the client metadata sent in this operation.
func (h *Hello) AppName(appname string) *Hello {
h.appname = appname
return h
}
// ClusterClock sets the cluster clock for this operation.
func (h *Hello) ClusterClock(clock *session.ClusterClock) *Hello {
if h == nil {
h = new(Hello)
}
h.clock = clock
return h
}
// Compressors sets the compressors that can be used.
func (h *Hello) Compressors(compressors []string) *Hello {
h.compressors = compressors
return h
}
// SASLSupportedMechs retrieves the supported SASL mechanism for the given user when this operation
// is run.
func (h *Hello) SASLSupportedMechs(username string) *Hello {
h.saslSupportedMechs = username
return h
}
// Deployment sets the Deployment for this operation.
func (h *Hello) Deployment(d driver.Deployment) *Hello {
h.d = d
return h
}
// SpeculativeAuthenticate sets the document to be used for speculative authentication.
func (h *Hello) SpeculativeAuthenticate(doc bsoncore.Document) *Hello {
h.speculativeAuth = doc
return h
}
// TopologyVersion sets the TopologyVersion to be used for heartbeats.
func (h *Hello) TopologyVersion(tv *description.TopologyVersion) *Hello {
h.topologyVersion = tv
return h
}
// MaxAwaitTimeMS sets the maximum time for the server to wait for topology changes during a heartbeat.
func (h *Hello) MaxAwaitTimeMS(awaitTime int64) *Hello {
h.maxAwaitTimeMS = &awaitTime
return h
}
// ServerAPI sets the server API version for this operation.
func (h *Hello) ServerAPI(serverAPI *driver.ServerAPIOptions) *Hello {
h.serverAPI = serverAPI
return h
}
// LoadBalanced specifies whether or not this operation is being sent over a connection to a load balanced cluster.
func (h *Hello) LoadBalanced(lb bool) *Hello {
h.loadBalanced = lb
return h
}
// OuterLibraryName specifies the name of the library wrapping the Go Driver.
func (h *Hello) OuterLibraryName(name string) *Hello {
h.outerLibraryName = name
return h
}
// OuterLibraryVersion specifies the version of the library wrapping the Go
// Driver.
func (h *Hello) OuterLibraryVersion(version string) *Hello {
h.outerLibraryVersion = version
return h
}
// OuterLibraryPlatform specifies the platform of the library wrapping the Go
// Driver.
func (h *Hello) OuterLibraryPlatform(platform string) *Hello {
h.outerLibraryPlatform = platform
return h
}
// Result returns the result of executing this operation.
func (h *Hello) Result(addr address.Address) description.Server {
return driverutil.NewServerDescription(addr, bson.Raw(h.res))
}
const dockerEnvPath = "/.dockerenv"
const (
// Runtime names
runtimeNameDocker = "docker"
// Orchestrator names
orchestratorNameK8s = "kubernetes"
)
// getFaasEnvName parses the FaaS environment variable name and returns the
// corresponding name used by the client. If none of the variables or variables
// for multiple names are populated the FaaS values MUST be entirely omitted.
// When variables for multiple "client.env.name" values are present, "vercel"
// takes precedence over "aws.lambda"; any other combination MUST cause FaaS
// values to be entirely omitted.
func getFaasEnvName() string {
envVars := []string{
driverutil.EnvVarAWSExecutionEnv,
driverutil.EnvVarAWSLambdaRuntimeAPI,
driverutil.EnvVarFunctionsWorkerRuntime,
driverutil.EnvVarKService,
driverutil.EnvVarFunctionName,
driverutil.EnvVarVercel,
}
// If none of the variables are populated the client.env value MUST be
// entirely omitted.
names := make(map[string]struct{})
for _, envVar := range envVars {
val := os.Getenv(envVar)
if val == "" {
continue
}
var name string
switch envVar {
case driverutil.EnvVarAWSExecutionEnv:
if !strings.HasPrefix(val, driverutil.AwsLambdaPrefix) {
continue
}
name = driverutil.EnvNameAWSLambda
case driverutil.EnvVarAWSLambdaRuntimeAPI:
name = driverutil.EnvNameAWSLambda
case driverutil.EnvVarFunctionsWorkerRuntime:
name = driverutil.EnvNameAzureFunc
case driverutil.EnvVarKService, driverutil.EnvVarFunctionName:
name = driverutil.EnvNameGCPFunc
case driverutil.EnvVarVercel:
// "vercel" takes precedence over "aws.lambda".
delete(names, driverutil.EnvNameAWSLambda)
name = driverutil.EnvNameVercel
}
names[name] = struct{}{}
if len(names) > 1 {
// If multiple names are populated the client.env value
// MUST be entirely omitted.
names = nil
break
}
}
for name := range names {
return name
}
return ""
}
type containerInfo struct {
runtime string
orchestrator string
}
// getContainerEnvInfo returns runtime and orchestrator of a container.
// If no fields is populated, the client.env.container value MUST be entirely
// omitted.
func getContainerEnvInfo() *containerInfo {
var runtime, orchestrator string
if _, err := os.Stat(dockerEnvPath); !os.IsNotExist(err) {
runtime = runtimeNameDocker
}
if v := os.Getenv(driverutil.EnvVarK8s); v != "" {
orchestrator = orchestratorNameK8s
}
if runtime != "" || orchestrator != "" {
return &containerInfo{
runtime: runtime,
orchestrator: orchestrator,
}
}
return nil
}
// appendClientAppName appends the application metadata to the dst. It is the
// responsibility of the caller to check that this appending does not cause dst
// to exceed any size limitations.
func appendClientAppName(dst []byte, name string) ([]byte, error) {
if name == "" {
return dst, nil
}
var idx int32
idx, dst = bsoncore.AppendDocumentElementStart(dst, "application")
dst = bsoncore.AppendStringElement(dst, "name", name)
return bsoncore.AppendDocumentEnd(dst, idx)
}
// appendClientDriver appends the driver metadata to dst. It is the
// responsibility of the caller to check that this appending does not cause dst
// to exceed any size limitations.
func appendClientDriver(dst []byte, outerLibraryName, outerLibraryVersion string) ([]byte, error) {
var idx int32
idx, dst = bsoncore.AppendDocumentElementStart(dst, "driver")
name := driverName
if outerLibraryName != "" {
name = name + "|" + outerLibraryName
}
version := version.Driver
if outerLibraryVersion != "" {
version = version + "|" + outerLibraryVersion
}
dst = bsoncore.AppendStringElement(dst, "name", name)
dst = bsoncore.AppendStringElement(dst, "version", version)
return bsoncore.AppendDocumentEnd(dst, idx)
}
// appendClientEnv appends the environment metadata to dst. It is the
// responsibility of the caller to check that this appending does not cause dst
// to exceed any size limitations.
func appendClientEnv(dst []byte, omitNonName, omitDoc bool) ([]byte, error) {
if omitDoc {
return dst, nil
}
name := getFaasEnvName()
container := getContainerEnvInfo()
// Omit the entire 'env' if both name and container are empty because other
// fields depend on either of them.
if name == "" && container == nil {
return dst, nil
}
var idx int32
idx, dst = bsoncore.AppendDocumentElementStart(dst, "env")
if name != "" {
dst = bsoncore.AppendStringElement(dst, "name", name)
}
addMem := func(envVar string) []byte {
mem := os.Getenv(envVar)
if mem == "" {
return dst
}
memInt64, err := strconv.ParseInt(mem, 10, 32)
if err != nil {
return dst
}
memInt32 := int32(memInt64)
return bsoncore.AppendInt32Element(dst, "memory_mb", memInt32)
}
addRegion := func(envVar string) []byte {
region := os.Getenv(envVar)
if region == "" {
return dst
}
return bsoncore.AppendStringElement(dst, "region", region)
}
addTimeout := func(envVar string) []byte {
timeout := os.Getenv(envVar)
if timeout == "" {
return dst
}
timeoutInt64, err := strconv.ParseInt(timeout, 10, 32)
if err != nil {
return dst
}
timeoutInt32 := int32(timeoutInt64)
return bsoncore.AppendInt32Element(dst, "timeout_sec", timeoutInt32)
}
if !omitNonName {
// No other FaaS fields will be populated if the name is empty.
switch name {
case driverutil.EnvNameAWSLambda:
dst = addMem(driverutil.EnvVarAWSLambdaFunctionMemorySize)
dst = addRegion(driverutil.EnvVarAWSRegion)
case driverutil.EnvNameGCPFunc:
dst = addMem(driverutil.EnvVarFunctionMemoryMB)
dst = addRegion(driverutil.EnvVarFunctionRegion)
dst = addTimeout(driverutil.EnvVarFunctionTimeoutSec)
case driverutil.EnvNameVercel:
dst = addRegion(driverutil.EnvVarVercelRegion)
}
}
if container != nil {
var idxCntnr int32
idxCntnr, dst = bsoncore.AppendDocumentElementStart(dst, "container")
if container.runtime != "" {
dst = bsoncore.AppendStringElement(dst, "runtime", container.runtime)
}
if container.orchestrator != "" {
dst = bsoncore.AppendStringElement(dst, "orchestrator", container.orchestrator)
}
var err error
dst, err = bsoncore.AppendDocumentEnd(dst, idxCntnr)
if err != nil {
return dst, err
}
}
return bsoncore.AppendDocumentEnd(dst, idx)
}
// appendClientOS appends the OS metadata to dst. It is the responsibility of the
// caller to check that this appending does not cause dst to exceed any size
// limitations.
func appendClientOS(dst []byte, omitNonType bool) ([]byte, error) {
var idx int32
idx, dst = bsoncore.AppendDocumentElementStart(dst, "os")
dst = bsoncore.AppendStringElement(dst, "type", runtime.GOOS)
if !omitNonType {
dst = bsoncore.AppendStringElement(dst, "architecture", runtime.GOARCH)
}
return bsoncore.AppendDocumentEnd(dst, idx)
}
// appendClientPlatform appends the platform metadata to dst. It is the
// responsibility of the caller to check that this appending does not cause dst
// to exceed any size limitations.
func appendClientPlatform(dst []byte, outerLibraryPlatform string) []byte {
platform := runtime.Version()
if outerLibraryPlatform != "" {
platform = platform + "|" + outerLibraryPlatform
}
return bsoncore.AppendStringElement(dst, "platform", platform)
}
// encodeClientMetadata encodes the client metadata into a BSON document. maxLen
// is the maximum length the document can be. If the document exceeds maxLen,
// then an empty byte slice is returned. If there is not enough space to encode
// a document, the document is truncated and returned.
//
// This function attempts to build the following document. Fields are omitted to
// save space following the MongoDB Handshake.
//
// {
// application: {
// name: "<string>"
// },
// driver: {
// name: "<string>",
// version: "<string>"
// },
// platform: "<string>",
// os: {
// type: "<string>",
// name: "<string>",
// architecture: "<string>",
// version: "<string>"
// },
// env: {
// name: "<string>",
// timeout_sec: 42,
// memory_mb: 1024,
// region: "<string>",
// container: {
// runtime: "<string>",
// orchestrator: "<string>"
// }
// }
// }
func encodeClientMetadata(h *Hello, maxLen int) ([]byte, error) {
dst := make([]byte, 0, maxLen)
omitEnvDoc := false
omitEnvNonName := false
omitOSNonType := false
omitEnvDocument := false
truncatePlatform := false
retry:
var idx int32
idx, dst = bsoncore.AppendDocumentStart(dst)
var err error
dst, err = appendClientAppName(dst, h.appname)
if err != nil {
return nil, err
}
dst, err = appendClientDriver(dst, h.outerLibraryName, h.outerLibraryVersion)
if err != nil {
return nil, err
}
dst, err = appendClientOS(dst, omitOSNonType)
if err != nil {
return nil, err
}
if !truncatePlatform {
dst = appendClientPlatform(dst, h.outerLibraryPlatform)
}
if !omitEnvDocument {
dst, err = appendClientEnv(dst, omitEnvNonName, omitEnvDoc)
if err != nil {
return nil, err
}
}
dst, err = bsoncore.AppendDocumentEnd(dst, idx)
if err != nil {
return nil, err
}
if len(dst) > maxLen {
// Implementers SHOULD cumulatively update fields in the
// following order until the document is under the size limit
//
// 1. Omit fields from ``env`` except ``env.name``
// 2. Omit fields from ``os`` except ``os.type``
// 3. Omit the ``env`` document entirely
// 4. Truncate ``platform``
dst = dst[:0]
if !omitEnvNonName {
omitEnvNonName = true
goto retry
}
if !omitOSNonType {
omitOSNonType = true
goto retry
}
if !omitEnvDoc {
omitEnvDoc = true
goto retry
}
if !truncatePlatform {
truncatePlatform = true
goto retry
}
// There is nothing left to update. Return an empty slice to
// tell caller not to append a `client` document.
return nil, nil
}
return dst, nil
}
// handshakeCommand appends all necessary command fields as well as client metadata, SASL supported mechs, and compression.
func (h *Hello) handshakeCommand(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst, err := h.command(dst, desc)
if err != nil {
return dst, err
}
if h.saslSupportedMechs != "" {
dst = bsoncore.AppendStringElement(dst, "saslSupportedMechs", h.saslSupportedMechs)
}
if h.speculativeAuth != nil {
dst = bsoncore.AppendDocumentElement(dst, "speculativeAuthenticate", h.speculativeAuth)
}
var idx int32
idx, dst = bsoncore.AppendArrayElementStart(dst, "compression")
for i, compressor := range h.compressors {
dst = bsoncore.AppendStringElement(dst, strconv.Itoa(i), compressor)
}
dst, _ = bsoncore.AppendArrayEnd(dst, idx)
clientMetadata, _ := encodeClientMetadata(h, maxClientMetadataSize)
// If the client metadata is empty, do not append it to the command.
if len(clientMetadata) > 0 {
dst = bsoncore.AppendDocumentElement(dst, "client", clientMetadata)
}
return dst, nil
}
// command appends all necessary command fields.
func (h *Hello) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
// Use "hello" if topology is LoadBalanced, API version is declared or server
// has responded with "helloOk". Otherwise, use legacy hello.
if h.loadBalanced || h.serverAPI != nil || desc.HelloOK {
dst = bsoncore.AppendInt32Element(dst, "hello", 1)
} else {
dst = bsoncore.AppendInt32Element(dst, handshake.LegacyHello, 1)
}
dst = bsoncore.AppendBooleanElement(dst, "helloOk", true)
if tv := h.topologyVersion; tv != nil {
var tvIdx int32
tvIdx, dst = bsoncore.AppendDocumentElementStart(dst, "topologyVersion")
dst = bsoncore.AppendObjectIDElement(dst, "processId", tv.ProcessID)
dst = bsoncore.AppendInt64Element(dst, "counter", tv.Counter)
dst, _ = bsoncore.AppendDocumentEnd(dst, tvIdx)
}
if h.maxAwaitTimeMS != nil {
dst = bsoncore.AppendInt64Element(dst, "maxAwaitTimeMS", *h.maxAwaitTimeMS)
}
if h.loadBalanced {
// The loadBalanced parameter should only be added if it's true. We should never explicitly send
// loadBalanced=false per the load balancing spec.
dst = bsoncore.AppendBooleanElement(dst, "loadBalanced", true)
}
return dst, nil
}
// Execute runs this operation.
func (h *Hello) Execute(ctx context.Context) error {
if h.d == nil {
return errors.New("a Hello must have a Deployment set before Execute can be called")
}
return h.createOperation().Execute(ctx)
}
// StreamResponse gets the next streaming Hello response from the server.
func (h *Hello) StreamResponse(ctx context.Context, conn *mnet.Connection) error {
return h.createOperation().ExecuteExhaust(ctx, conn)
}
// isLegacyHandshake returns True if server API version is not requested and
// loadBalanced is False. If this is the case, then the drivers MUST use legacy
// hello for the first message of the initial handshake with the OP_QUERY
// protocol
func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, loadbalanced bool) bool {
return srvAPI == nil && !loadbalanced
}
func (h *Hello) createOperation() driver.Operation {
op := driver.Operation{
Clock: h.clock,
CommandFn: h.command,
Database: "admin",
Deployment: h.d,
ProcessResponseFn: func(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
h.res = resp
return nil
},
ServerAPI: h.serverAPI,
OmitMaxTimeMS: h.omitMaxTimeMS,
}
if isLegacyHandshake(h.serverAPI, h.loadBalanced) {
op.Legacy = driver.LegacyHandshake
}
return op
}
// GetHandshakeInformation performs the MongoDB handshake for the provided connection and returns the relevant
// information about the server. This function implements the driver.Handshaker interface.
func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, conn *mnet.Connection) (driver.HandshakeInformation, error) {
deployment := driver.SingleConnectionDeployment{C: conn}
op := driver.Operation{
Clock: h.clock,
CommandFn: h.handshakeCommand,
Deployment: deployment,
Database: "admin",
ProcessResponseFn: func(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
h.res = resp
return nil
},
ServerAPI: h.serverAPI,
}
if isLegacyHandshake(h.serverAPI, h.loadBalanced) {
op.Legacy = driver.LegacyHandshake
}
if err := op.Execute(ctx); err != nil {
return driver.HandshakeInformation{}, err
}
info := driver.HandshakeInformation{
Description: h.Result(conn.Address()),
}
if speculativeAuthenticate, ok := h.res.Lookup("speculativeAuthenticate").DocumentOK(); ok {
info.SpeculativeAuthenticate = speculativeAuthenticate
}
if serverConnectionID, ok := h.res.Lookup("connectionId").AsInt64OK(); ok {
info.ServerConnectionID = &serverConnectionID
}
var err error
// Cast to bson.Raw to lookup saslSupportedMechs to avoid converting from bsoncore.Value to bson.RawValue for the
// StringSliceFromRawValue call.
if saslSupportedMechs, lookupErr := bson.Raw(h.res).LookupErr("saslSupportedMechs"); lookupErr == nil {
info.SaslSupportedMechs, err = bsonutil.StringSliceFromRawValue("saslSupportedMechs", saslSupportedMechs)
}
return info, err
}
// FinishHandshake implements the Handshaker interface. This is a no-op function because a non-authenticated connection
// does not do anything besides the initial Hello for a handshake.
func (h *Hello) FinishHandshake(context.Context, *mnet.Connection) error {
return nil
}
// OmitMaxTimeMS will ensure maxTimMS is not included in the wire message
// constructed to send a hello request.
func (h *Hello) OmitMaxTimeMS(val bool) *Hello {
if h == nil {
h = new(Hello)
}
h.omitMaxTimeMS = val
return h
}
// Authenticator sets the authenticator to use for this operation.
func (h *Hello) Authenticator(authenticator driver.Authenticator) *Hello {
if h == nil {
h = new(Hello)
}
h.authenticator = authenticator
return h
}

View File

@@ -0,0 +1,291 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// ListCollections performs a listCollections operation.
type ListCollections struct {
authenticator driver.Authenticator
filter bsoncore.Document
nameOnly *bool
authorizedCollections *bool
session *session.Client
clock *session.ClusterClock
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
readPreference *readpref.ReadPref
selector description.ServerSelector
retry *driver.RetryMode
result driver.CursorResponse
batchSize *int32
serverAPI *driver.ServerAPIOptions
timeout *time.Duration
rawData *bool
}
// NewListCollections constructs and returns a new ListCollections.
func NewListCollections(filter bsoncore.Document) *ListCollections {
return &ListCollections{
filter: filter,
}
}
// Result returns the result of executing this operation.
func (lc *ListCollections) Result(opts driver.CursorOptions) (*driver.BatchCursor, error) {
opts.ServerAPI = lc.serverAPI
return driver.NewBatchCursor(lc.result, lc.session, lc.clock, opts)
}
func (lc *ListCollections) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error {
curDoc, err := driver.ExtractCursorDocument(resp)
if err != nil {
return err
}
lc.result, err = driver.NewCursorResponse(curDoc, info)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (lc *ListCollections) Execute(ctx context.Context) error {
if lc.deployment == nil {
return errors.New("the ListCollections operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: lc.command,
ProcessResponseFn: lc.processResponse,
RetryMode: lc.retry,
Type: driver.Read,
Client: lc.session,
Clock: lc.clock,
CommandMonitor: lc.monitor,
Crypt: lc.crypt,
Database: lc.database,
Deployment: lc.deployment,
ReadPreference: lc.readPreference,
Selector: lc.selector,
Legacy: driver.LegacyListCollections,
ServerAPI: lc.serverAPI,
Timeout: lc.timeout,
Name: driverutil.ListCollectionsOp,
Authenticator: lc.authenticator,
}.Execute(ctx)
}
func (lc *ListCollections) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "listCollections", 1)
if lc.filter != nil {
dst = bsoncore.AppendDocumentElement(dst, "filter", lc.filter)
}
if lc.nameOnly != nil {
dst = bsoncore.AppendBooleanElement(dst, "nameOnly", *lc.nameOnly)
}
if lc.authorizedCollections != nil {
dst = bsoncore.AppendBooleanElement(dst, "authorizedCollections", *lc.authorizedCollections)
}
cursorDoc := bsoncore.NewDocumentBuilder()
if lc.batchSize != nil {
cursorDoc.AppendInt32("batchSize", *lc.batchSize)
}
dst = bsoncore.AppendDocumentElement(dst, "cursor", cursorDoc.Build())
// Set rawData for 8.2+ servers.
if lc.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
dst = bsoncore.AppendBooleanElement(dst, "rawData", *lc.rawData)
}
return dst, nil
}
// Filter determines what results are returned from listCollections.
func (lc *ListCollections) Filter(filter bsoncore.Document) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.filter = filter
return lc
}
// NameOnly specifies whether to only return collection names.
func (lc *ListCollections) NameOnly(nameOnly bool) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.nameOnly = &nameOnly
return lc
}
// AuthorizedCollections specifies whether to only return collections the user
// is authorized to use.
func (lc *ListCollections) AuthorizedCollections(authorizedCollections bool) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.authorizedCollections = &authorizedCollections
return lc
}
// Session sets the session for this operation.
func (lc *ListCollections) Session(session *session.Client) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.session = session
return lc
}
// ClusterClock sets the cluster clock for this operation.
func (lc *ListCollections) ClusterClock(clock *session.ClusterClock) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.clock = clock
return lc
}
// CommandMonitor sets the monitor to use for APM events.
func (lc *ListCollections) CommandMonitor(monitor *event.CommandMonitor) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.monitor = monitor
return lc
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (lc *ListCollections) Crypt(crypt driver.Crypt) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.crypt = crypt
return lc
}
// Database sets the database to run this operation against.
func (lc *ListCollections) Database(database string) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.database = database
return lc
}
// Deployment sets the deployment to use for this operation.
func (lc *ListCollections) Deployment(deployment driver.Deployment) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.deployment = deployment
return lc
}
// ReadPreference set the read preference used with this operation.
func (lc *ListCollections) ReadPreference(readPreference *readpref.ReadPref) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.readPreference = readPreference
return lc
}
// ServerSelector sets the selector used to retrieve a server.
func (lc *ListCollections) ServerSelector(selector description.ServerSelector) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.selector = selector
return lc
}
// Retry enables retryable mode for this operation. Retries are handled automatically in driver.Operation.Execute based
// on how the operation is set.
func (lc *ListCollections) Retry(retry driver.RetryMode) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.retry = &retry
return lc
}
// BatchSize specifies the number of documents to return in every batch.
func (lc *ListCollections) BatchSize(batchSize int32) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.batchSize = &batchSize
return lc
}
// ServerAPI sets the server API version for this operation.
func (lc *ListCollections) ServerAPI(serverAPI *driver.ServerAPIOptions) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.serverAPI = serverAPI
return lc
}
// Timeout sets the timeout for this operation.
func (lc *ListCollections) Timeout(timeout *time.Duration) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.timeout = timeout
return lc
}
// Authenticator sets the authenticator to use for this operation.
func (lc *ListCollections) Authenticator(authenticator driver.Authenticator) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.authenticator = authenticator
return lc
}
// RawData sets the rawData to access timeseries data in the compressed format.
func (lc *ListCollections) RawData(rawData bool) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.rawData = &rawData
return lc
}

View File

@@ -0,0 +1,336 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"fmt"
"time"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// ListDatabases performs a listDatabases operation.
type ListDatabases struct {
authenticator driver.Authenticator
filter bsoncore.Document
authorizedDatabases *bool
nameOnly *bool
session *session.Client
clock *session.ClusterClock
monitor *event.CommandMonitor
database string
deployment driver.Deployment
readPreference *readpref.ReadPref
retry *driver.RetryMode
selector description.ServerSelector
crypt driver.Crypt
serverAPI *driver.ServerAPIOptions
timeout *time.Duration
result ListDatabasesResult
}
// ListDatabasesResult represents a listDatabases result returned by the server.
type ListDatabasesResult struct {
// An array of documents, one document for each database
Databases []databaseRecord
// The sum of the size of all the database files on disk in bytes.
TotalSize int64
}
type databaseRecord struct {
Name string
SizeOnDisk int64 `bson:"sizeOnDisk"`
Empty bool
}
func buildListDatabasesResult(response bsoncore.Document) (ListDatabasesResult, error) {
elements, err := response.Elements()
if err != nil {
return ListDatabasesResult{}, err
}
ir := ListDatabasesResult{}
for _, element := range elements {
switch element.Key() {
case "totalSize":
var ok bool
ir.TotalSize, ok = element.Value().AsInt64OK()
if !ok {
return ir, fmt.Errorf("response field 'totalSize' is type int64, but received BSON type %s: %s", element.Value().Type, element.Value())
}
case "databases":
arr, ok := element.Value().ArrayOK()
if !ok {
return ir, fmt.Errorf("response field 'databases' is type array, but received BSON type %s", element.Value().Type)
}
var tmp bsoncore.Document
err := bson.Unmarshal(arr, &tmp)
if err != nil {
return ir, err
}
records, err := tmp.Elements()
if err != nil {
return ir, err
}
ir.Databases = make([]databaseRecord, len(records))
for i, val := range records {
valueDoc, ok := val.Value().DocumentOK()
if !ok {
return ir, fmt.Errorf("'databases' element is type document, but received BSON type %s", val.Value().Type)
}
elems, err := valueDoc.Elements()
if err != nil {
return ir, err
}
for _, elem := range elems {
switch elem.Key() {
case "name":
ir.Databases[i].Name, ok = elem.Value().StringValueOK()
if !ok {
return ir, fmt.Errorf("response field 'name' is type string, but received BSON type %s", elem.Value().Type)
}
case "sizeOnDisk":
ir.Databases[i].SizeOnDisk, ok = elem.Value().AsInt64OK()
if !ok {
return ir, fmt.Errorf("response field 'sizeOnDisk' is type int64, but received BSON type %s", elem.Value().Type)
}
case "empty":
ir.Databases[i].Empty, ok = elem.Value().BooleanOK()
if !ok {
return ir, fmt.Errorf("response field 'empty' is type bool, but received BSON type %s", elem.Value().Type)
}
}
}
}
}
}
return ir, nil
}
// NewListDatabases constructs and returns a new ListDatabases.
func NewListDatabases(filter bsoncore.Document) *ListDatabases {
return &ListDatabases{
filter: filter,
}
}
// Result returns the result of executing this operation.
func (ld *ListDatabases) Result() ListDatabasesResult { return ld.result }
func (ld *ListDatabases) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
var err error
ld.result, err = buildListDatabasesResult(resp)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (ld *ListDatabases) Execute(ctx context.Context) error {
if ld.deployment == nil {
return errors.New("the ListDatabases operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: ld.command,
ProcessResponseFn: ld.processResponse,
Client: ld.session,
Clock: ld.clock,
CommandMonitor: ld.monitor,
Database: ld.database,
Deployment: ld.deployment,
ReadPreference: ld.readPreference,
RetryMode: ld.retry,
Type: driver.Read,
Selector: ld.selector,
Crypt: ld.crypt,
ServerAPI: ld.serverAPI,
Timeout: ld.timeout,
Name: driverutil.ListDatabasesOp,
Authenticator: ld.authenticator,
}.Execute(ctx)
}
func (ld *ListDatabases) command(dst []byte, _ description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "listDatabases", 1)
if ld.filter != nil {
dst = bsoncore.AppendDocumentElement(dst, "filter", ld.filter)
}
if ld.nameOnly != nil {
dst = bsoncore.AppendBooleanElement(dst, "nameOnly", *ld.nameOnly)
}
if ld.authorizedDatabases != nil {
dst = bsoncore.AppendBooleanElement(dst, "authorizedDatabases", *ld.authorizedDatabases)
}
return dst, nil
}
// Filter determines what results are returned from listDatabases.
func (ld *ListDatabases) Filter(filter bsoncore.Document) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.filter = filter
return ld
}
// NameOnly specifies whether to only return database names.
func (ld *ListDatabases) NameOnly(nameOnly bool) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.nameOnly = &nameOnly
return ld
}
// AuthorizedDatabases specifies whether to only return databases which the user is authorized to use."
func (ld *ListDatabases) AuthorizedDatabases(authorizedDatabases bool) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.authorizedDatabases = &authorizedDatabases
return ld
}
// Session sets the session for this operation.
func (ld *ListDatabases) Session(session *session.Client) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.session = session
return ld
}
// ClusterClock sets the cluster clock for this operation.
func (ld *ListDatabases) ClusterClock(clock *session.ClusterClock) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.clock = clock
return ld
}
// CommandMonitor sets the monitor to use for APM events.
func (ld *ListDatabases) CommandMonitor(monitor *event.CommandMonitor) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.monitor = monitor
return ld
}
// Database sets the database to run this operation against.
func (ld *ListDatabases) Database(database string) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.database = database
return ld
}
// Deployment sets the deployment to use for this operation.
func (ld *ListDatabases) Deployment(deployment driver.Deployment) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.deployment = deployment
return ld
}
// ReadPreference set the read preference used with this operation.
func (ld *ListDatabases) ReadPreference(readPreference *readpref.ReadPref) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.readPreference = readPreference
return ld
}
// ServerSelector sets the selector used to retrieve a server.
func (ld *ListDatabases) ServerSelector(selector description.ServerSelector) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.selector = selector
return ld
}
// Retry enables retryable mode for this operation. Retries are handled automatically in driver.Operation.Execute based
// on how the operation is set.
func (ld *ListDatabases) Retry(retry driver.RetryMode) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.retry = &retry
return ld
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (ld *ListDatabases) Crypt(crypt driver.Crypt) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.crypt = crypt
return ld
}
// ServerAPI sets the server API version for this operation.
func (ld *ListDatabases) ServerAPI(serverAPI *driver.ServerAPIOptions) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.serverAPI = serverAPI
return ld
}
// Timeout sets the timeout for this operation.
func (ld *ListDatabases) Timeout(timeout *time.Duration) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.timeout = timeout
return ld
}
// Authenticator sets the authenticator to use for this operation.
func (ld *ListDatabases) Authenticator(authenticator driver.Authenticator) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.authenticator = authenticator
return ld
}

View File

@@ -0,0 +1,248 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// ListIndexes performs a listIndexes operation.
type ListIndexes struct {
authenticator driver.Authenticator
batchSize *int32
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
selector description.ServerSelector
retry *driver.RetryMode
crypt driver.Crypt
serverAPI *driver.ServerAPIOptions
timeout *time.Duration
rawData *bool
result driver.CursorResponse
}
// NewListIndexes constructs and returns a new ListIndexes.
func NewListIndexes() *ListIndexes {
return &ListIndexes{}
}
// Result returns the result of executing this operation.
func (li *ListIndexes) Result(opts driver.CursorOptions) (*driver.BatchCursor, error) {
clientSession := li.session
clock := li.clock
opts.ServerAPI = li.serverAPI
return driver.NewBatchCursor(li.result, clientSession, clock, opts)
}
func (li *ListIndexes) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error {
curDoc, err := driver.ExtractCursorDocument(resp)
if err != nil {
return err
}
li.result, err = driver.NewCursorResponse(curDoc, info)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (li *ListIndexes) Execute(ctx context.Context) error {
if li.deployment == nil {
return errors.New("the ListIndexes operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: li.command,
ProcessResponseFn: li.processResponse,
Client: li.session,
Clock: li.clock,
CommandMonitor: li.monitor,
Database: li.database,
Deployment: li.deployment,
Selector: li.selector,
Crypt: li.crypt,
Legacy: driver.LegacyListIndexes,
RetryMode: li.retry,
Type: driver.Read,
ServerAPI: li.serverAPI,
Timeout: li.timeout,
Name: driverutil.ListIndexesOp,
Authenticator: li.authenticator,
}.Execute(ctx)
}
func (li *ListIndexes) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "listIndexes", li.collection)
cursorIdx, cursorDoc := bsoncore.AppendDocumentStart(nil)
if li.batchSize != nil {
cursorDoc = bsoncore.AppendInt32Element(cursorDoc, "batchSize", *li.batchSize)
}
cursorDoc, _ = bsoncore.AppendDocumentEnd(cursorDoc, cursorIdx)
dst = bsoncore.AppendDocumentElement(dst, "cursor", cursorDoc)
// Set rawData for 8.2+ servers.
if li.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
dst = bsoncore.AppendBooleanElement(dst, "rawData", *li.rawData)
}
return dst, nil
}
// BatchSize specifies the number of documents to return in every batch.
func (li *ListIndexes) BatchSize(batchSize int32) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.batchSize = &batchSize
return li
}
// Session sets the session for this operation.
func (li *ListIndexes) Session(session *session.Client) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.session = session
return li
}
// ClusterClock sets the cluster clock for this operation.
func (li *ListIndexes) ClusterClock(clock *session.ClusterClock) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.clock = clock
return li
}
// Collection sets the collection that this command will run against.
func (li *ListIndexes) Collection(collection string) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.collection = collection
return li
}
// CommandMonitor sets the monitor to use for APM events.
func (li *ListIndexes) CommandMonitor(monitor *event.CommandMonitor) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.monitor = monitor
return li
}
// Database sets the database to run this operation against.
func (li *ListIndexes) Database(database string) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.database = database
return li
}
// Deployment sets the deployment to use for this operation.
func (li *ListIndexes) Deployment(deployment driver.Deployment) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.deployment = deployment
return li
}
// ServerSelector sets the selector used to retrieve a server.
func (li *ListIndexes) ServerSelector(selector description.ServerSelector) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.selector = selector
return li
}
// Retry enables retryable mode for this operation. Retries are handled automatically in driver.Operation.Execute based
// on how the operation is set.
func (li *ListIndexes) Retry(retry driver.RetryMode) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.retry = &retry
return li
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (li *ListIndexes) Crypt(crypt driver.Crypt) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.crypt = crypt
return li
}
// ServerAPI sets the server API version for this operation.
func (li *ListIndexes) ServerAPI(serverAPI *driver.ServerAPIOptions) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.serverAPI = serverAPI
return li
}
// Timeout sets the timeout for this operation.
func (li *ListIndexes) Timeout(timeout *time.Duration) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.timeout = timeout
return li
}
// Authenticator sets the authenticator to use for this operation.
func (li *ListIndexes) Authenticator(authenticator driver.Authenticator) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.authenticator = authenticator
return li
}
// RawData sets the rawData to access timeseries data in the compressed format.
func (li *ListIndexes) RawData(rawData bool) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.rawData = &rawData
return li
}

View File

@@ -0,0 +1,454 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"fmt"
"time"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
"go.mongodb.org/mongo-driver/v2/internal/logger"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// Update performs an update operation.
type Update struct {
authenticator driver.Authenticator
bypassDocumentValidation *bool
comment bsoncore.Value
ordered *bool
updates []bsoncore.Document
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
hint *bool
arrayFilters *bool
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
retry *driver.RetryMode
result UpdateResult
crypt driver.Crypt
serverAPI *driver.ServerAPIOptions
let bsoncore.Document
timeout *time.Duration
rawData *bool
additionalCmd bson.D
logger *logger.Logger
}
// Upsert contains the information for an upsert in an Update operation.
type Upsert struct {
Index int64
ID any `bson:"_id"`
}
// UpdateResult contains information for the result of an Update operation.
type UpdateResult struct {
// Number of documents matched.
N int64
// Number of documents modified.
NModified int64
// Information about upserted documents.
Upserted []Upsert
}
func buildUpdateResult(response bsoncore.Document) (UpdateResult, error) {
elements, err := response.Elements()
if err != nil {
return UpdateResult{}, err
}
ur := UpdateResult{}
for _, element := range elements {
switch element.Key() {
case "nModified":
var ok bool
ur.NModified, ok = element.Value().AsInt64OK()
if !ok {
return ur, fmt.Errorf("response field 'nModified' is type int32 or int64, but received BSON type %s", element.Value().Type)
}
case "n":
var ok bool
ur.N, ok = element.Value().AsInt64OK()
if !ok {
return ur, fmt.Errorf("response field 'n' is type int32 or int64, but received BSON type %s", element.Value().Type)
}
case "upserted":
arr, ok := element.Value().ArrayOK()
if !ok {
return ur, fmt.Errorf("response field 'upserted' is type array, but received BSON type %s", element.Value().Type)
}
var values []bsoncore.Value
values, err = arr.Values()
if err != nil {
break
}
for _, val := range values {
valDoc, ok := val.DocumentOK()
if !ok {
return ur, fmt.Errorf("upserted value is type document, but received BSON type %s", val.Type)
}
var upsert Upsert
if err = bson.Unmarshal(valDoc, &upsert); err != nil {
return ur, err
}
ur.Upserted = append(ur.Upserted, upsert)
}
}
}
return ur, nil
}
// NewUpdate constructs and returns a new Update.
func NewUpdate(updates ...bsoncore.Document) *Update {
return &Update{
updates: updates,
}
}
// Result returns the result of executing this operation.
func (u *Update) Result() UpdateResult { return u.result }
func (u *Update) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error {
ur, err := buildUpdateResult(resp)
u.result.N += ur.N
u.result.NModified += ur.NModified
if info.CurrentIndex > 0 {
for ind := range ur.Upserted {
ur.Upserted[ind].Index += int64(info.CurrentIndex)
}
}
u.result.Upserted = append(u.result.Upserted, ur.Upserted...)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (u *Update) Execute(ctx context.Context) error {
if u.deployment == nil {
return errors.New("the Update operation must have a Deployment set before Execute can be called")
}
batches := &driver.Batches{
Identifier: "updates",
Documents: u.updates,
Ordered: u.ordered,
}
return driver.Operation{
CommandFn: u.command,
ProcessResponseFn: u.processResponse,
Batches: batches,
RetryMode: u.retry,
Type: driver.Write,
Client: u.session,
Clock: u.clock,
CommandMonitor: u.monitor,
Database: u.database,
Deployment: u.deployment,
Selector: u.selector,
WriteConcern: u.writeConcern,
Crypt: u.crypt,
ServerAPI: u.serverAPI,
Timeout: u.timeout,
Logger: u.logger,
Name: driverutil.UpdateOp,
Authenticator: u.authenticator,
}.Execute(ctx)
}
func (u *Update) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "update", u.collection)
if u.bypassDocumentValidation != nil &&
(desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 4)) {
dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *u.bypassDocumentValidation)
}
if u.comment.Type != bsoncore.Type(0) {
dst = bsoncore.AppendValueElement(dst, "comment", u.comment)
}
if u.ordered != nil {
dst = bsoncore.AppendBooleanElement(dst, "ordered", *u.ordered)
}
if u.hint != nil && *u.hint {
if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) {
return nil, errors.New("the 'hint' command parameter requires a minimum server wire version of 5")
}
if !u.writeConcern.Acknowledged() {
return nil, errUnacknowledgedHint
}
}
if u.arrayFilters != nil && *u.arrayFilters {
if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 6) {
return nil, errors.New("the 'arrayFilters' command parameter requires a minimum server wire version of 6")
}
}
if u.let != nil {
dst = bsoncore.AppendDocumentElement(dst, "let", u.let)
}
// Set rawData for 8.2+ servers.
if u.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
dst = bsoncore.AppendBooleanElement(dst, "rawData", *u.rawData)
}
if len(u.additionalCmd) > 0 {
doc, err := bson.Marshal(u.additionalCmd)
if err != nil {
return nil, err
}
dst = append(dst, doc[4:len(doc)-1]...)
}
return dst, nil
}
// BypassDocumentValidation allows the operation to opt-out of document level validation.
func (u *Update) BypassDocumentValidation(bypassDocumentValidation bool) *Update {
if u == nil {
u = new(Update)
}
u.bypassDocumentValidation = &bypassDocumentValidation
return u
}
// Hint is a flag to indicate that the update document contains a hint. Hint is only supported by
// servers >= 4.2. Older servers will report an error for using the hint option.
func (u *Update) Hint(hint bool) *Update {
if u == nil {
u = new(Update)
}
u.hint = &hint
return u
}
// ArrayFilters is a flag to indicate that the update document contains an arrayFilters field.
func (u *Update) ArrayFilters(arrayFilters bool) *Update {
if u == nil {
u = new(Update)
}
u.arrayFilters = &arrayFilters
return u
}
// Ordered sets ordered. If true, when a write fails, the operation will return the error, when
// false write failures do not stop execution of the operation.
func (u *Update) Ordered(ordered bool) *Update {
if u == nil {
u = new(Update)
}
u.ordered = &ordered
return u
}
// Updates specifies an array of update statements to perform when this operation is executed.
// Each update document must have the following structure:
// {q: <query>, u: <update>, multi: <boolean>, collation: Optional<Document>, arrayFitlers: Optional<Array>, hint: Optional<string/Document>}.
func (u *Update) Updates(updates ...bsoncore.Document) *Update {
if u == nil {
u = new(Update)
}
u.updates = updates
return u
}
// Session sets the session for this operation.
func (u *Update) Session(session *session.Client) *Update {
if u == nil {
u = new(Update)
}
u.session = session
return u
}
// ClusterClock sets the cluster clock for this operation.
func (u *Update) ClusterClock(clock *session.ClusterClock) *Update {
if u == nil {
u = new(Update)
}
u.clock = clock
return u
}
// Collection sets the collection that this command will run against.
func (u *Update) Collection(collection string) *Update {
if u == nil {
u = new(Update)
}
u.collection = collection
return u
}
// CommandMonitor sets the monitor to use for APM events.
func (u *Update) CommandMonitor(monitor *event.CommandMonitor) *Update {
if u == nil {
u = new(Update)
}
u.monitor = monitor
return u
}
// Comment sets a value to help trace an operation.
func (u *Update) Comment(comment bsoncore.Value) *Update {
if u == nil {
u = new(Update)
}
u.comment = comment
return u
}
// Database sets the database to run this operation against.
func (u *Update) Database(database string) *Update {
if u == nil {
u = new(Update)
}
u.database = database
return u
}
// Deployment sets the deployment to use for this operation.
func (u *Update) Deployment(deployment driver.Deployment) *Update {
if u == nil {
u = new(Update)
}
u.deployment = deployment
return u
}
// ServerSelector sets the selector used to retrieve a server.
func (u *Update) ServerSelector(selector description.ServerSelector) *Update {
if u == nil {
u = new(Update)
}
u.selector = selector
return u
}
// WriteConcern sets the write concern for this operation.
func (u *Update) WriteConcern(writeConcern *writeconcern.WriteConcern) *Update {
if u == nil {
u = new(Update)
}
u.writeConcern = writeConcern
return u
}
// Retry enables retryable writes for this operation. Retries are not handled automatically,
// instead a boolean is returned from Execute and SelectAndExecute that indicates if the
// operation can be retried. Retrying is handled by calling RetryExecute.
func (u *Update) Retry(retry driver.RetryMode) *Update {
if u == nil {
u = new(Update)
}
u.retry = &retry
return u
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (u *Update) Crypt(crypt driver.Crypt) *Update {
if u == nil {
u = new(Update)
}
u.crypt = crypt
return u
}
// ServerAPI sets the server API version for this operation.
func (u *Update) ServerAPI(serverAPI *driver.ServerAPIOptions) *Update {
if u == nil {
u = new(Update)
}
u.serverAPI = serverAPI
return u
}
// Let specifies the let document to use. This option is only valid for server versions 5.0 and above.
func (u *Update) Let(let bsoncore.Document) *Update {
if u == nil {
u = new(Update)
}
u.let = let
return u
}
// Timeout sets the timeout for this operation.
func (u *Update) Timeout(timeout *time.Duration) *Update {
if u == nil {
u = new(Update)
}
u.timeout = timeout
return u
}
// Logger sets the logger for this operation.
func (u *Update) Logger(logger *logger.Logger) *Update {
if u == nil {
u = new(Update)
}
u.logger = logger
return u
}
// Authenticator sets the authenticator to use for this operation.
func (u *Update) Authenticator(authenticator driver.Authenticator) *Update {
if u == nil {
u = new(Update)
}
u.authenticator = authenticator
return u
}
// RawData sets the rawData to access timeseries data in the compressed format.
func (u *Update) RawData(rawData bool) *Update {
if u == nil {
u = new(Update)
}
u.rawData = &rawData
return u
}
// AdditionalCmd sets additional command fields to be attached.
func (u *Update) AdditionalCmd(d bson.D) *Update {
if u == nil {
u = new(Update)
}
u.additionalCmd = d
return u
}

View File

@@ -0,0 +1,237 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package operation
import (
"context"
"errors"
"fmt"
"time"
"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
// UpdateSearchIndex performs a updateSearchIndex operation.
type UpdateSearchIndex struct {
authenticator driver.Authenticator
index string
definition bsoncore.Document
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
crypt driver.Crypt
database string
deployment driver.Deployment
selector description.ServerSelector
result UpdateSearchIndexResult
serverAPI *driver.ServerAPIOptions
timeout *time.Duration
}
// UpdateSearchIndexResult represents a single index in the updateSearchIndexResult result.
type UpdateSearchIndexResult struct {
Ok int32
}
func buildUpdateSearchIndexResult(response bsoncore.Document) (UpdateSearchIndexResult, error) {
elements, err := response.Elements()
if err != nil {
return UpdateSearchIndexResult{}, err
}
usir := UpdateSearchIndexResult{}
for _, element := range elements {
if element.Key() == "ok" {
var ok bool
usir.Ok, ok = element.Value().AsInt32OK()
if !ok {
return usir, fmt.Errorf("response field 'ok' is type int32, but received BSON type %s", element.Value().Type)
}
}
}
return usir, nil
}
// NewUpdateSearchIndex constructs and returns a new UpdateSearchIndex.
func NewUpdateSearchIndex(index string, definition bsoncore.Document) *UpdateSearchIndex {
return &UpdateSearchIndex{
index: index,
definition: definition,
}
}
// Result returns the result of executing this operation.
func (usi *UpdateSearchIndex) Result() UpdateSearchIndexResult { return usi.result }
func (usi *UpdateSearchIndex) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error {
var err error
usi.result, err = buildUpdateSearchIndexResult(resp)
return err
}
// Execute runs this operations and returns an error if the operation did not execute successfully.
func (usi *UpdateSearchIndex) Execute(ctx context.Context) error {
if usi.deployment == nil {
return errors.New("the UpdateSearchIndex operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: usi.command,
ProcessResponseFn: usi.processResponse,
Client: usi.session,
Clock: usi.clock,
CommandMonitor: usi.monitor,
Crypt: usi.crypt,
Database: usi.database,
Deployment: usi.deployment,
Selector: usi.selector,
ServerAPI: usi.serverAPI,
Timeout: usi.timeout,
Authenticator: usi.authenticator,
}.Execute(ctx)
}
func (usi *UpdateSearchIndex) command(dst []byte, _ description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "updateSearchIndex", usi.collection)
dst = bsoncore.AppendStringElement(dst, "name", usi.index)
dst = bsoncore.AppendDocumentElement(dst, "definition", usi.definition)
return dst, nil
}
// Index specifies the index of the document being updated.
func (usi *UpdateSearchIndex) Index(name string) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.index = name
return usi
}
// Definition specifies the definition for the document being created.
func (usi *UpdateSearchIndex) Definition(definition bsoncore.Document) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.definition = definition
return usi
}
// Session sets the session for this operation.
func (usi *UpdateSearchIndex) Session(session *session.Client) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.session = session
return usi
}
// ClusterClock sets the cluster clock for this operation.
func (usi *UpdateSearchIndex) ClusterClock(clock *session.ClusterClock) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.clock = clock
return usi
}
// Collection sets the collection that this command will run against.
func (usi *UpdateSearchIndex) Collection(collection string) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.collection = collection
return usi
}
// CommandMonitor sets the monitor to use for APM events.
func (usi *UpdateSearchIndex) CommandMonitor(monitor *event.CommandMonitor) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.monitor = monitor
return usi
}
// Crypt sets the Crypt object to use for automatic encryption and decryption.
func (usi *UpdateSearchIndex) Crypt(crypt driver.Crypt) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.crypt = crypt
return usi
}
// Database sets the database to run this operation against.
func (usi *UpdateSearchIndex) Database(database string) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.database = database
return usi
}
// Deployment sets the deployment to use for this operation.
func (usi *UpdateSearchIndex) Deployment(deployment driver.Deployment) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.deployment = deployment
return usi
}
// ServerSelector sets the selector used to retrieve a server.
func (usi *UpdateSearchIndex) ServerSelector(selector description.ServerSelector) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.selector = selector
return usi
}
// ServerAPI sets the server API version for this operation.
func (usi *UpdateSearchIndex) ServerAPI(serverAPI *driver.ServerAPIOptions) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.serverAPI = serverAPI
return usi
}
// Timeout sets the timeout for this operation.
func (usi *UpdateSearchIndex) Timeout(timeout *time.Duration) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.timeout = timeout
return usi
}
// Authenticator sets the authenticator to use for this operation.
func (usi *UpdateSearchIndex) Authenticator(authenticator driver.Authenticator) *UpdateSearchIndex {
if usi == nil {
usi = new(UpdateSearchIndex)
}
usi.authenticator = authenticator
return usi
}

View File

@@ -0,0 +1,38 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet"
)
// ExecuteExhaust reads a response from the provided StreamerConnection. This will error if the connection's
// CurrentlyStreaming function returns false.
func (op Operation) ExecuteExhaust(ctx context.Context, conn *mnet.Connection) error {
if !conn.CurrentlyStreaming() {
return errors.New("exhaust read must be done with a connection that is currently streaming")
}
res, err := op.readWireMessage(ctx, conn)
if err != nil {
return err
}
if op.ProcessResponseFn != nil {
// Server, ConnectionDescription, and CurrentIndex are unused in this mode.
info := ResponseInfo{
Connection: conn,
}
if err = op.ProcessResponseFn(ctx, res, info); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,36 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
// TestServerAPIVersion is the most recent, stable variant of options.ServerAPIVersion.
// Only to be used in testing.
const TestServerAPIVersion = "1"
// ServerAPIOptions represents arguments used to configure the API version sent
// to the server when running commands.
type ServerAPIOptions struct {
ServerAPIVersion string
Strict *bool
DeprecationErrors *bool
}
// NewServerAPIOptions creates a new ServerAPIOptions configured with the provided serverAPIVersion.
func NewServerAPIOptions(serverAPIVersion string) *ServerAPIOptions {
return &ServerAPIOptions{ServerAPIVersion: serverAPIVersion}
}
// SetStrict specifies whether the server should return errors for features that are not part of the API version.
func (s *ServerAPIOptions) SetStrict(strict bool) *ServerAPIOptions {
s.Strict = &strict
return s
}
// SetDeprecationErrors specifies whether the server should return errors for deprecated features.
func (s *ServerAPIOptions) SetDeprecationErrors(deprecationErrors bool) *ServerAPIOptions {
s.DeprecationErrors = &deprecationErrors
return s
}

View File

@@ -0,0 +1,554 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package session
import (
"errors"
"time"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/uuid"
"go.mongodb.org/mongo-driver/v2/mongo/address"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet"
)
// ErrSessionEnded is returned when a client session is used after a call to endSession().
var ErrSessionEnded = errors.New("ended session was used")
// ErrNoTransactStarted is returned if a transaction operation is called when no transaction has started.
var ErrNoTransactStarted = errors.New("no transaction started")
// ErrTransactInProgress is returned if startTransaction() is called when a transaction is in progress.
var ErrTransactInProgress = errors.New("transaction already in progress")
// ErrAbortAfterCommit is returned when abort is called after a commit.
var ErrAbortAfterCommit = errors.New("cannot call abortTransaction after calling commitTransaction")
// ErrAbortTwice is returned if abort is called after transaction is already aborted.
var ErrAbortTwice = errors.New("cannot call abortTransaction twice")
// ErrCommitAfterAbort is returned if commit is called after an abort.
var ErrCommitAfterAbort = errors.New("cannot call commitTransaction after calling abortTransaction")
// ErrUnackWCUnsupported is returned if an unacknowledged write concern is supported for a transaction.
var ErrUnackWCUnsupported = errors.New("transactions do not support unacknowledged write concerns")
// ErrSnapshotTransaction is returned if an transaction is started on a snapshot session.
var ErrSnapshotTransaction = errors.New("transactions are not supported in snapshot sessions")
// TransactionState indicates the state of the transactions FSM.
type TransactionState uint8
// Client Session states
const (
None TransactionState = iota
Starting
InProgress
Committed
Aborted
)
const defaultWriteConcernTimeout = 10_000 * time.Millisecond
// String implements the fmt.Stringer interface.
func (s TransactionState) String() string {
switch s {
case None:
return "none"
case Starting:
return "starting"
case InProgress:
return "in progress"
case Committed:
return "committed"
case Aborted:
return "aborted"
default:
return "unknown"
}
}
var _ mnet.Pinner = (LoadBalancedTransactionConnection)(nil)
// LoadBalancedTransactionConnection represents a connection that's pinned by a ClientSession because it's being used
// to execute a transaction when running against a load balancer. This interface is a copy of driver.PinnedConnection
// and exists to be able to pin transactions to a connection without causing an import cycle.
type LoadBalancedTransactionConnection interface {
mnet.ReadWriteCloser
mnet.Describer
mnet.Pinner
}
// Client is a session for clients to run commands.
type Client struct {
*Server
ClientID uuid.UUID
ClusterTime bson.Raw
Consistent bool // causal consistency
OperationTime *bson.Timestamp
IsImplicit bool
Terminated bool
RetryingCommit bool
Committing bool
Aborting bool
Snapshot bool
// SnapshotTime is the atClusterTime value for snapshot reads. This field is
// left immutable once set for the lifetime of the session. This guards
// against users updating custom snapshot times during transactions which
// could lead to a write conflict.
SnapshotTime bson.Timestamp
SnapshotTimeSet bool
// options for the current transaction
// most recently set by transactionopt
CurrentRc *readconcern.ReadConcern
CurrentRp *readpref.ReadPref
CurrentWc *writeconcern.WriteConcern
CurrentWTimeout time.Duration
// default transaction options
transactionRc *readconcern.ReadConcern
transactionRp *readpref.ReadPref
transactionWc *writeconcern.WriteConcern
pool *Pool
TransactionState TransactionState
PinnedServerAddr *address.Address
RecoveryToken bson.Raw
PinnedConnection LoadBalancedTransactionConnection
}
func getClusterTime(clusterTime bson.Raw) (uint32, uint32) {
if clusterTime == nil {
return 0, 0
}
clusterTimeVal, err := clusterTime.LookupErr("$clusterTime")
if err != nil {
return 0, 0
}
timestampVal, err := bson.Raw(clusterTimeVal.Value).LookupErr("clusterTime")
if err != nil {
return 0, 0
}
return timestampVal.Timestamp()
}
// MaxClusterTime compares 2 clusterTime documents and returns the document representing the highest cluster time.
func MaxClusterTime(ct1, ct2 bson.Raw) bson.Raw {
epoch1, ord1 := getClusterTime(ct1)
epoch2, ord2 := getClusterTime(ct2)
switch {
case epoch1 > epoch2:
return ct1
case epoch1 < epoch2:
return ct2
case ord1 > ord2:
return ct1
case ord1 < ord2:
return ct2
}
return ct1
}
// NewImplicitClientSession creates a new implicit client-side session.
func NewImplicitClientSession(pool *Pool, clientID uuid.UUID) *Client {
// Server-side session checkout for implicit sessions is deferred until after checking out a
// connection, so don't check out a server-side session right now. This will limit the number of
// implicit sessions to no greater than an application's maxPoolSize.
return &Client{
pool: pool,
ClientID: clientID,
IsImplicit: true,
}
}
// NewClientSession creates a new explicit client-side session.
func NewClientSession(pool *Pool, clientID uuid.UUID, opts ...*ClientOptions) (*Client, error) {
c := &Client{
pool: pool,
ClientID: clientID,
}
mergedOpts := mergeClientOptions(opts...)
if mergedOpts.DefaultReadPreference != nil {
c.transactionRp = mergedOpts.DefaultReadPreference
}
if mergedOpts.DefaultReadConcern != nil {
c.transactionRc = mergedOpts.DefaultReadConcern
}
if mergedOpts.DefaultWriteConcern != nil {
c.transactionWc = mergedOpts.DefaultWriteConcern
}
if mergedOpts.Snapshot != nil {
c.Snapshot = *mergedOpts.Snapshot
}
if mergedOpts.SnapshotTime != nil {
c.SnapshotTime = *mergedOpts.SnapshotTime
c.SnapshotTimeSet = true
}
// For explicit sessions, the default for causalConsistency is true, unless Snapshot is
// enabled, then it's false. Set the default and then allow any explicit causalConsistency
// setting to override it.
c.Consistent = !c.Snapshot
if mergedOpts.CausalConsistency != nil {
c.Consistent = *mergedOpts.CausalConsistency
}
if c.Consistent && c.Snapshot {
return nil, errors.New("causal consistency and snapshot cannot both be set for a session")
}
if c.SnapshotTimeSet && !c.Snapshot {
return nil, errors.New("snapshotTime cannot be set when snapshot is false")
}
if err := c.SetServer(); err != nil {
return nil, err
}
return c, nil
}
// SetServer will check out a session from the client session pool.
func (c *Client) SetServer() error {
var err error
c.Server, err = c.pool.GetSession()
return err
}
// AdvanceClusterTime updates the session's cluster time.
func (c *Client) AdvanceClusterTime(clusterTime bson.Raw) error {
if c.Terminated {
return ErrSessionEnded
}
c.ClusterTime = MaxClusterTime(c.ClusterTime, clusterTime)
return nil
}
// AdvanceOperationTime updates the session's operation time.
func (c *Client) AdvanceOperationTime(opTime *bson.Timestamp) error {
if c.Terminated {
return ErrSessionEnded
}
if c.OperationTime == nil {
c.OperationTime = opTime
return nil
}
if opTime.T > c.OperationTime.T {
c.OperationTime = opTime
} else if (opTime.T == c.OperationTime.T) && (opTime.I > c.OperationTime.I) {
c.OperationTime = opTime
}
return nil
}
// UpdateUseTime sets the session's last used time to the current time. This must be called whenever the session is
// used to send a command to the server to ensure that the session is not prematurely marked expired in the driver's
// session pool. If the session has already been ended, this method will return ErrSessionEnded.
func (c *Client) UpdateUseTime() error {
if c.Terminated {
return ErrSessionEnded
}
c.updateUseTime()
return nil
}
// UpdateRecoveryToken updates the session's recovery token from the server response.
func (c *Client) UpdateRecoveryToken(response bson.Raw) {
if c == nil {
return
}
token, err := response.LookupErr("recoveryToken")
if err != nil {
return
}
c.RecoveryToken = token.Document()
}
// UpdateSnapshotTime updates the session's value for the atClusterTime field of
// ReadConcern.
func (c *Client) UpdateSnapshotTime(response bsoncore.Document) {
if c == nil || c.SnapshotTimeSet {
// Do nothing if session is nil or snapshot time is already set. The driver
// sends the same atClusterTime for all operations in a snapshot session so
// resetting is a potentially dangerous redundancy.
return
}
subDoc := response
if cur, ok := response.Lookup("cursor").DocumentOK(); ok {
subDoc = cur
}
ssTimeElem, err := subDoc.LookupErr("atClusterTime")
if err != nil {
// atClusterTime not included by the server
return
}
t, i := ssTimeElem.Timestamp()
c.SnapshotTime = bson.Timestamp{
T: t,
I: i,
}
c.SnapshotTimeSet = true
}
// ClearPinnedResources clears the pinned server and/or connection associated with the session.
func (c *Client) ClearPinnedResources() error {
if c == nil {
return nil
}
c.PinnedServerAddr = nil
if c.PinnedConnection != nil {
if err := c.PinnedConnection.UnpinFromTransaction(); err != nil {
return err
}
if err := c.PinnedConnection.Close(); err != nil {
return err
}
}
c.PinnedConnection = nil
return nil
}
// unpinConnection gracefully unpins the connection associated with the session
// if there is one. This is done via the pinned connection's
// UnpinFromTransaction function.
func (c *Client) unpinConnection() error {
if c == nil || c.PinnedConnection == nil {
return nil
}
err := c.PinnedConnection.UnpinFromTransaction()
closeErr := c.PinnedConnection.Close()
if err == nil && closeErr != nil {
err = closeErr
}
c.PinnedConnection = nil
return err
}
// EndSession ends the session.
func (c *Client) EndSession() {
if c.Terminated {
return
}
c.Terminated = true
// Ignore the error when unpinning the connection because we can't do
// anything about it if it doesn't work. Typically the only errors that can
// happen here indicate that something went wrong with the connection state,
// like it wasn't marked as pinned or attempted to return to the wrong pool.
_ = c.unpinConnection()
c.pool.ReturnSession(c.Server)
}
// TransactionInProgress returns true if the client session is in an active transaction.
func (c *Client) TransactionInProgress() bool {
return c.TransactionState == InProgress
}
// TransactionStarting returns true if the client session is starting a transaction.
func (c *Client) TransactionStarting() bool {
return c.TransactionState == Starting
}
// TransactionRunning returns true if the client session has started the transaction
// and it hasn't been committed or aborted
func (c *Client) TransactionRunning() bool {
return c != nil && (c.TransactionState == Starting || c.TransactionState == InProgress)
}
// TransactionCommitted returns true of the client session just committed a transaction.
func (c *Client) TransactionCommitted() bool {
return c.TransactionState == Committed
}
// CheckStartTransaction checks to see if allowed to start transaction and returns
// an error if not allowed
func (c *Client) CheckStartTransaction() error {
if c.TransactionState == InProgress || c.TransactionState == Starting {
return ErrTransactInProgress
}
if c.Snapshot {
return ErrSnapshotTransaction
}
return nil
}
// StartTransaction initializes the transaction options and advances the state machine.
// It does not contact the server to start the transaction.
func (c *Client) StartTransaction(opts *TransactionOptions) error {
err := c.CheckStartTransaction()
if err != nil {
return err
}
c.IncrementTxnNumber()
c.RetryingCommit = false
if opts != nil {
c.CurrentRc = opts.ReadConcern
c.CurrentRp = opts.ReadPreference
c.CurrentWc = opts.WriteConcern
}
if c.CurrentRc == nil {
c.CurrentRc = c.transactionRc
}
if c.CurrentRp == nil {
c.CurrentRp = c.transactionRp
}
if c.CurrentWc == nil {
c.CurrentWc = c.transactionWc
}
if !c.CurrentWc.Acknowledged() {
_ = c.clearTransactionOpts()
return ErrUnackWCUnsupported
}
c.TransactionState = Starting
return c.ClearPinnedResources()
}
// CheckCommitTransaction checks to see if allowed to commit transaction and returns
// an error if not allowed.
func (c *Client) CheckCommitTransaction() error {
switch c.TransactionState {
case None:
return ErrNoTransactStarted
case Aborted:
return ErrCommitAfterAbort
}
return nil
}
// CommitTransaction updates the state for a successfully committed transaction and returns
// an error if not permissible. It does not actually perform the commit.
func (c *Client) CommitTransaction() error {
err := c.CheckCommitTransaction()
if err != nil {
return err
}
c.TransactionState = Committed
return nil
}
// UpdateCommitTransactionWriteConcern will set the write concern to majority.
// This should be called after a commit transaction operation fails with a
// retryable error or after a successful commit transaction operation
//
// Per the transaction specifications, when commitTransaction is retried, if
// the modified write concern does not include a "wtimeout" value, drivers
// MUST apply "wtimeout: 10000" to the write concern in order to avoid waiting
// forever (oruntil a socket timeout) if the majority write concern cannot be
// satisfied. This field abstracts that functionality. For more information,
// see SPEC-1185.
func (c *Client) UpdateCommitTransactionWriteConcern() {
c.CurrentWc = &writeconcern.WriteConcern{
W: "majority",
}
c.CurrentWTimeout = defaultWriteConcernTimeout
}
// CheckAbortTransaction checks to see if allowed to abort transaction and returns
// an error if not allowed.
func (c *Client) CheckAbortTransaction() error {
switch c.TransactionState {
case None:
return ErrNoTransactStarted
case Committed:
return ErrAbortAfterCommit
case Aborted:
return ErrAbortTwice
}
return nil
}
// AbortTransaction updates the state for a successfully aborted transaction and returns
// an error if not permissible. It does not actually perform the abort.
func (c *Client) AbortTransaction() error {
err := c.CheckAbortTransaction()
if err != nil {
return err
}
c.TransactionState = Aborted
return c.clearTransactionOpts()
}
// StartCommand updates the session's internal state at the beginning of an operation. This must be called before
// server selection is done for the operation as the session's state can impact the result of that process.
func (c *Client) StartCommand() error {
if c == nil {
return nil
}
// If we're executing the first operation using this session after a transaction, we must ensure that the session
// is not pinned to any resources.
if !c.TransactionRunning() && !c.Committing && !c.Aborting {
return c.ClearPinnedResources()
}
return nil
}
// ApplyCommand advances the state machine upon command execution. This must be called after server selection is
// complete.
func (c *Client) ApplyCommand(desc description.Server) error {
if c.Committing {
// Do not change state if committing after already committed
return nil
}
switch c.TransactionState {
case Starting:
c.TransactionState = InProgress
// If this is in a transaction and the server is a mongos, pin it
if desc.Kind == description.ServerKindMongos {
c.PinnedServerAddr = &desc.Addr
}
case Committed, Aborted:
c.TransactionState = None
return c.clearTransactionOpts()
}
return nil
}
func (c *Client) clearTransactionOpts() error {
c.RetryingCommit = false
c.Aborting = false
c.Committing = false
c.CurrentWc = nil
c.CurrentRp = nil
c.CurrentRc = nil
c.RecoveryToken = nil
return c.ClearPinnedResources()
}

View File

@@ -0,0 +1,36 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package session
import (
"sync"
"go.mongodb.org/mongo-driver/v2/bson"
)
// ClusterClock represents a logical clock for keeping track of cluster time.
type ClusterClock struct {
clusterTime bson.Raw
lock sync.Mutex
}
// GetClusterTime returns the cluster's current time.
func (cc *ClusterClock) GetClusterTime() bson.Raw {
var ct bson.Raw
cc.lock.Lock()
ct = cc.clusterTime
cc.lock.Unlock()
return ct
}
// AdvanceClusterTime updates the cluster's current time.
func (cc *ClusterClock) AdvanceClusterTime(clusterTime bson.Raw) {
cc.lock.Lock()
cc.clusterTime = MaxClusterTime(cc.clusterTime, clusterTime)
cc.lock.Unlock()
}

View File

@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package session is intended for internal use only. It is made available to
// facilitate use cases that require access to internal MongoDB driver
// functionality and state. The API of this package is not stable and there is
// no backward compatibility guarantee.
//
// WARNING: THIS PACKAGE IS EXPERIMENTAL AND MAY BE MODIFIED OR REMOVED WITHOUT
// NOTICE! USE WITH EXTREME CAUTION!
package session

View File

@@ -0,0 +1,60 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package session
import (
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
)
// ClientOptions represents all possible options for creating a client session.
type ClientOptions struct {
CausalConsistency *bool
DefaultReadConcern *readconcern.ReadConcern
DefaultWriteConcern *writeconcern.WriteConcern
DefaultReadPreference *readpref.ReadPref
Snapshot *bool
SnapshotTime *bson.Timestamp
}
// TransactionOptions represents all possible options for starting a transaction in a session.
type TransactionOptions struct {
ReadConcern *readconcern.ReadConcern
WriteConcern *writeconcern.WriteConcern
ReadPreference *readpref.ReadPref
}
func mergeClientOptions(opts ...*ClientOptions) *ClientOptions {
c := &ClientOptions{}
for _, opt := range opts {
if opt == nil {
continue
}
if opt.CausalConsistency != nil {
c.CausalConsistency = opt.CausalConsistency
}
if opt.DefaultReadConcern != nil {
c.DefaultReadConcern = opt.DefaultReadConcern
}
if opt.DefaultReadPreference != nil {
c.DefaultReadPreference = opt.DefaultReadPreference
}
if opt.DefaultWriteConcern != nil {
c.DefaultWriteConcern = opt.DefaultWriteConcern
}
if opt.Snapshot != nil {
c.Snapshot = opt.Snapshot
}
if opt.SnapshotTime != nil {
c.SnapshotTime = opt.SnapshotTime
}
}
return c
}

View File

@@ -0,0 +1,74 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package session
import (
"time"
"go.mongodb.org/mongo-driver/v2/internal/uuid"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
)
// Server is an open session with the server.
type Server struct {
SessionID bsoncore.Document
TxnNumber int64
LastUsed time.Time
Dirty bool
}
// returns whether or not a session has expired given a timeout in minutes
// a session is considered expired if it has less than 1 minute left before becoming stale
func (ss *Server) expired(topoDesc topologyDescription) bool {
// There is no server monitoring in LB mode, so we do not track session timeout minutes from server hello responses
// and never consider sessions to be expired.
if topoDesc.kind == description.TopologyKindLoadBalanced {
return false
}
if topoDesc.timeoutMinutes == nil || *topoDesc.timeoutMinutes <= 0 {
return true
}
timeUnused := time.Since(ss.LastUsed).Minutes()
return timeUnused > float64(*topoDesc.timeoutMinutes-1)
}
// update the last used time for this session.
// must be called whenever this server session is used to send a command to the server.
func (ss *Server) updateUseTime() {
ss.LastUsed = time.Now()
}
func newServerSession() (*Server, error) {
id, err := uuid.New()
if err != nil {
return nil, err
}
idx, idDoc := bsoncore.AppendDocumentStart(nil)
idDoc = bsoncore.AppendBinaryElement(idDoc, "id", UUIDSubtype, id[:])
idDoc, _ = bsoncore.AppendDocumentEnd(idDoc, idx)
return &Server{
SessionID: idDoc,
LastUsed: time.Now(),
}, nil
}
// IncrementTxnNumber increments the transaction number.
func (ss *Server) IncrementTxnNumber() {
ss.TxnNumber++
}
// MarkDirty marks the session as dirty.
func (ss *Server) MarkDirty() {
ss.Dirty = true
}
// UUIDSubtype is the BSON binary subtype that a UUID should be encoded as
const UUIDSubtype byte = 4

View File

@@ -0,0 +1,192 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package session
import (
"sync"
"sync/atomic"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
)
// Node represents a server session in a linked list
type Node struct {
*Server
next *Node
prev *Node
}
// topologyDescription is used to track a subset of the fields present in a description.Topology instance that are
// relevant for determining session expiration.
type topologyDescription struct {
kind description.TopologyKind
timeoutMinutes *int64
}
// Pool is a pool of server sessions that can be reused.
type Pool struct {
// number of sessions checked out of pool (accessed atomically)
checkedOut int64
descChan <-chan description.Topology
head *Node
tail *Node
latestTopology topologyDescription
mutex sync.Mutex // mutex to protect list and sessionTimeout
}
func (p *Pool) createServerSession() (*Server, error) {
s, err := newServerSession()
if err != nil {
return nil, err
}
atomic.AddInt64(&p.checkedOut, 1)
return s, nil
}
// NewPool creates a new server session pool
func NewPool(descChan <-chan description.Topology) *Pool {
p := &Pool{
descChan: descChan,
}
return p
}
// assumes caller has mutex to protect the pool
func (p *Pool) updateTimeout() {
select {
case newDesc := <-p.descChan:
p.latestTopology = topologyDescription{
kind: newDesc.Kind,
timeoutMinutes: newDesc.SessionTimeoutMinutes,
}
default:
// no new description waiting
}
}
// GetSession retrieves an unexpired session from the pool.
func (p *Pool) GetSession() (*Server, error) {
p.mutex.Lock() // prevent changing the linked list while seeing if sessions have expired
defer p.mutex.Unlock()
// empty pool
if p.head == nil && p.tail == nil {
return p.createServerSession()
}
p.updateTimeout()
for p.head != nil {
// pull session from head of queue and return if it is valid for at least 1 more minute
if p.head.expired(p.latestTopology) {
p.head = p.head.next
continue
}
// found unexpired session
session := p.head.Server
if p.head.next != nil {
p.head.next.prev = nil
}
if p.tail == p.head {
p.tail = nil
p.head = nil
} else {
p.head = p.head.next
}
atomic.AddInt64(&p.checkedOut, 1)
return session, nil
}
// no valid session found
p.tail = nil // empty list
return p.createServerSession()
}
// ReturnSession returns a session to the pool if it has not expired.
func (p *Pool) ReturnSession(ss *Server) {
if ss == nil {
return
}
p.mutex.Lock()
defer p.mutex.Unlock()
atomic.AddInt64(&p.checkedOut, -1)
p.updateTimeout()
// check sessions at end of queue for expired
// stop checking after hitting the first valid session
for p.tail != nil && p.tail.expired(p.latestTopology) {
if p.tail.prev != nil {
p.tail.prev.next = nil
}
p.tail = p.tail.prev
}
// session expired
if ss.expired(p.latestTopology) {
return
}
// session is dirty
if ss.Dirty {
return
}
newNode := &Node{
Server: ss,
next: nil,
prev: nil,
}
// empty list
if p.tail == nil {
p.head = newNode
p.tail = newNode
return
}
// at least 1 valid session in list
newNode.next = p.head
p.head.prev = newNode
p.head = newNode
}
// IDSlice returns a slice of session IDs for each session in the pool
func (p *Pool) IDSlice() []bsoncore.Document {
p.mutex.Lock()
defer p.mutex.Unlock()
var ids []bsoncore.Document
for node := p.head; node != nil; node = node.next {
ids = append(ids, node.SessionID)
}
return ids
}
// String implements the Stringer interface
func (p *Pool) String() string {
p.mutex.Lock()
defer p.mutex.Unlock()
s := ""
for head := p.head; head != nil; head = head.next {
s += head.SessionID.String() + "\n"
}
return s
}
// CheckedOut returns number of sessions checked out from pool.
func (p *Pool) CheckedOut() int64 {
return atomic.LoadInt64(&p.checkedOut)
}

Some files were not shown because too many files have changed in this diff Show More