// Copyright (C) MongoDB, Inc. 2017-present. // // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 package mongo import ( "bytes" "context" "errors" "fmt" "io" "net" "reflect" "strconv" "strings" "go.mongodb.org/mongo-driver/v2/internal/codecutil" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/bson" ) var defaultRegistry = bson.NewRegistry() // Dialer is used to make network connections. type Dialer interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } // Pipeline is a type that makes creating aggregation pipelines easier. It is a // helper and is intended for serializing to BSON. // // Example usage: // // mongo.Pipeline{ // {{"$group", bson.D{{"_id", "$state"}, {"totalPop", bson.D{{"$sum", "$pop"}}}}}}, // {{"$match", bson.D{{"totalPop", bson.D{{"$gte", 10*1000*1000}}}}}}, // } type Pipeline []bson.D // getEncoder takes a writer, BSON options, and a BSON registry and returns a properly configured // bson.Encoder that writes to the given writer. func getEncoder( w io.Writer, opts *options.BSONOptions, reg *bson.Registry, ) *bson.Encoder { vw := bson.NewDocumentWriter(w) enc := bson.NewEncoder(vw) if opts != nil { if opts.ErrorOnInlineDuplicates { enc.ErrorOnInlineDuplicates() } if opts.IntMinSize { enc.IntMinSize() } if opts.NilByteSliceAsEmpty { enc.NilByteSliceAsEmpty() } if opts.NilMapAsEmpty { enc.NilMapAsEmpty() } if opts.NilSliceAsEmpty { enc.NilSliceAsEmpty() } if opts.OmitZeroStruct { enc.OmitZeroStruct() } if opts.OmitEmpty { enc.OmitEmpty() } if opts.StringifyMapKeysWithFmt { enc.StringifyMapKeysWithFmt() } if opts.UseJSONStructTags { enc.UseJSONStructTags() } } if reg != nil { enc.SetRegistry(reg) } return enc } // newEncoderFn will return a function for constructing an encoder based on the // provided codec options. func newEncoderFn(opts *options.BSONOptions, registry *bson.Registry) codecutil.EncoderFn { return func(w io.Writer) *bson.Encoder { return getEncoder(w, opts, registry) } } // marshal marshals the given value as a BSON document. Byte slices are always converted to a // bson.Raw before marshaling. // // If bsonOpts and registry are specified, the encoder is configured with the requested behaviors. // If they are nil, the default behaviors are used. func marshal( val any, bsonOpts *options.BSONOptions, registry *bson.Registry, ) (bsoncore.Document, error) { if registry == nil { registry = defaultRegistry } if val == nil { return nil, ErrNilDocument } if bs, ok := val.([]byte); ok { // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. val = bson.Raw(bs) } buf := new(bytes.Buffer) enc := getEncoder(buf, bsonOpts, registry) err := enc.Encode(val) if err != nil { return nil, MarshalError{Value: val, Err: err} } return buf.Bytes(), nil } // ensureID inserts the given ObjectID as an element named "_id" at the // beginning of the given BSON document if there is not an "_id" already. // If the given ObjectID is bson.NilObjectID, a new object ID will be // generated with time.Now(). // // If there is already an element named "_id", the document is not modified. It // returns the resulting document and the decoded Go value of the "_id" element. func ensureID( doc bsoncore.Document, oid bson.ObjectID, bsonOpts *options.BSONOptions, reg *bson.Registry, ) (bsoncore.Document, any, error) { if reg == nil { reg = defaultRegistry } // Try to find the "_id" element. If it exists, try to unmarshal just the // "_id" field as an any and return it along with the unmodified // BSON document. if _, err := doc.LookupErr("_id"); err == nil { var id struct { ID any `bson:"_id"` } dec := getDecoder(doc, bsonOpts, reg) err = dec.Decode(&id) if err != nil { return nil, nil, fmt.Errorf("error unmarshaling BSON document: %w", err) } return doc, id.ID, nil } // We couldn't find an "_id" element, so add one with the value of the // provided ObjectID. olddoc := doc // Reserve an extra 17 bytes for the "_id" field we're about to add: // type (1) + "_id" (3) + terminator (1) + object ID (12) const extraSpace = 17 doc = make(bsoncore.Document, 0, len(olddoc)+extraSpace) _, doc = bsoncore.ReserveLength(doc) if oid.IsZero() { oid = bson.NewObjectID() } doc = bsoncore.AppendObjectIDElement(doc, "_id", oid) // Remove and re-write the BSON document length header. const int32Len = 4 doc = append(doc, olddoc[int32Len:]...) doc = bsoncore.UpdateLength(doc, 0, int32(len(doc))) return doc, oid, nil } func ensureDollarKey(doc bsoncore.Document) error { firstElem, err := doc.IndexErr(0) if err != nil { return errors.New("update document must have at least one element") } if !strings.HasPrefix(firstElem.Key(), "$") { return errors.New("update document must contain key beginning with '$'") } return nil } func ensureNoDollarKey(doc bsoncore.Document) error { if elem, err := doc.IndexErr(0); err == nil && strings.HasPrefix(elem.Key(), "$") { return errors.New("replacement document cannot contain keys beginning with '$'") } return nil } func marshalAggregatePipeline( pipeline any, bsonOpts *options.BSONOptions, registry *bson.Registry, ) (bsoncore.Document, bool, error) { switch t := pipeline.(type) { case bson.ValueMarshaler: btype, val, err := t.MarshalBSONValue() if err != nil { return nil, false, err } if typ := bson.Type(btype); typ != bson.TypeArray { return nil, false, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", typ, bson.TypeArray) } var hasOutputStage bool pipelineDoc := bsoncore.Document(val) values, _ := pipelineDoc.Values() if pipelineLen := len(values); pipelineLen > 0 { if finalDoc, ok := values[pipelineLen-1].DocumentOK(); ok { if elem, err := finalDoc.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") { hasOutputStage = true } } } return pipelineDoc, hasOutputStage, nil default: val := reflect.ValueOf(t) if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) { return nil, false, fmt.Errorf("can only marshal slices and arrays into aggregation pipelines, but got %v", val.Kind()) } var hasOutputStage bool valLen := val.Len() switch t := pipeline.(type) { // Explicitly forbid non-empty pipelines that are semantically single documents // and are implemented as slices. case bson.D, bson.Raw, bsoncore.Document: if valLen > 0 { return nil, false, fmt.Errorf("%T is not an allowed pipeline type as it represents a single document. Use bson.A or mongo.Pipeline instead", t) } // bsoncore.Arrays do not need to be marshaled. Only check validity and presence of output stage. case bsoncore.Array: if err := t.Validate(); err != nil { return nil, false, err } values, err := t.Values() if err != nil { return nil, false, err } numVals := len(values) if numVals == 0 { return bsoncore.Document(t), false, nil } // If not empty, check if first value of the last stage is $out or $merge. if lastStage, ok := values[numVals-1].DocumentOK(); ok { if elem, err := lastStage.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") { hasOutputStage = true } } return bsoncore.Document(t), hasOutputStage, nil } aidx, arr := bsoncore.AppendArrayStart(nil) for idx := 0; idx < valLen; idx++ { doc, err := marshal(val.Index(idx).Interface(), bsonOpts, registry) if err != nil { return nil, false, err } if idx == valLen-1 { if elem, err := doc.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") { hasOutputStage = true } } arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(idx), doc) } arr, _ = bsoncore.AppendArrayEnd(arr, aidx) return arr, hasOutputStage, nil } } func marshalUpdateValue( update any, bsonOpts *options.BSONOptions, registry *bson.Registry, dollarKeysAllowed bool, ) (bsoncore.Value, error) { documentCheckerFunc := ensureDollarKey if !dollarKeysAllowed { documentCheckerFunc = ensureNoDollarKey } var u bsoncore.Value var err error switch t := update.(type) { case nil: return u, ErrNilDocument case bson.D: u.Type = bsoncore.TypeEmbeddedDocument u.Data, err = marshal(update, bsonOpts, registry) if err != nil { return u, err } return u, documentCheckerFunc(u.Data) case bson.Raw: u.Type = bsoncore.TypeEmbeddedDocument u.Data = t return u, documentCheckerFunc(u.Data) case bsoncore.Document: u.Type = bsoncore.TypeEmbeddedDocument u.Data = t return u, documentCheckerFunc(u.Data) case []byte: u.Type = bsoncore.TypeEmbeddedDocument u.Data = t return u, documentCheckerFunc(u.Data) case bson.Marshaler: u.Type = bsoncore.TypeEmbeddedDocument u.Data, err = t.MarshalBSON() if err != nil { return u, err } return u, documentCheckerFunc(u.Data) case bson.ValueMarshaler: tt, data, err := t.MarshalBSONValue() u.Type = bsoncore.Type(tt) u.Data = data if err != nil { return u, err } if u.Type != bsoncore.TypeArray && u.Type != bsoncore.TypeEmbeddedDocument { return u, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v or %v", u.Type, bsoncore.TypeArray, bsoncore.TypeEmbeddedDocument) } return u, err default: val := reflect.ValueOf(t) if !val.IsValid() { return u, fmt.Errorf("can only marshal slices and arrays into update pipelines, but got %v", val.Kind()) } if val.Kind() != reflect.Slice && val.Kind() != reflect.Array { u.Type = bsoncore.TypeEmbeddedDocument u.Data, err = marshal(update, bsonOpts, registry) if err != nil { return u, err } return u, documentCheckerFunc(u.Data) } u.Type = bsoncore.TypeArray aidx, arr := bsoncore.AppendArrayStart(nil) valLen := val.Len() for idx := 0; idx < valLen; idx++ { doc, err := marshal(val.Index(idx).Interface(), bsonOpts, registry) if err != nil { return u, err } if err := documentCheckerFunc(doc); err != nil { return u, err } arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(idx), doc) } u.Data, _ = bsoncore.AppendArrayEnd(arr, aidx) return u, err } } func marshalValue( val any, bsonOpts *options.BSONOptions, registry *bson.Registry, ) (bsoncore.Value, error) { return codecutil.MarshalValue(val, newEncoderFn(bsonOpts, registry)) } // Build the aggregation pipeline for the CountDocument command. func countDocumentsAggregatePipeline( filter any, encOpts *options.BSONOptions, registry *bson.Registry, args *options.CountOptions, ) (bsoncore.Document, error) { filterDoc, err := marshal(filter, encOpts, registry) if err != nil { return nil, err } aidx, arr := bsoncore.AppendArrayStart(nil) didx, arr := bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(0)) arr = bsoncore.AppendDocumentElement(arr, "$match", filterDoc) arr, _ = bsoncore.AppendDocumentEnd(arr, didx) index := 1 if args != nil { if args.Skip != nil { didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index)) arr = bsoncore.AppendInt64Element(arr, "$skip", *args.Skip) arr, _ = bsoncore.AppendDocumentEnd(arr, didx) index++ } if args.Limit != nil { didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index)) arr = bsoncore.AppendInt64Element(arr, "$limit", *args.Limit) arr, _ = bsoncore.AppendDocumentEnd(arr, didx) index++ } } didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index)) iidx, arr := bsoncore.AppendDocumentElementStart(arr, "$group") arr = bsoncore.AppendInt32Element(arr, "_id", 1) iiidx, arr := bsoncore.AppendDocumentElementStart(arr, "n") arr = bsoncore.AppendInt32Element(arr, "$sum", 1) arr, _ = bsoncore.AppendDocumentEnd(arr, iiidx) arr, _ = bsoncore.AppendDocumentEnd(arr, iidx) arr, _ = bsoncore.AppendDocumentEnd(arr, didx) return bsoncore.AppendArrayEnd(arr, aidx) }