From e1bb7d65a842b3e9521bc2ad79d34a70cde2352f Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Tue, 17 Oct 2023 14:15:37 -0700 Subject: [PATCH] Implement rust repeated scalars for cpp and upb PiperOrigin-RevId: 574261929 --- Cargo.bazel.lock | 65 ++++++- Cargo.lock | 8 +- WORKSPACE | 3 + rust/BUILD | 9 +- rust/cpp.rs | 137 +++++++++++++- rust/cpp_kernel/BUILD | 3 +- rust/cpp_kernel/cpp_api.cc | 35 ++++ rust/internal.rs | 14 ++ rust/repeated.rs | 119 ++++++++++++ rust/shared.rs | 2 + rust/test/cpp/interop/test_utils.cc | 2 +- rust/test/shared/BUILD | 6 + rust/test/shared/accessors_test.rs | 49 +++++ rust/upb.rs | 176 +++++++++++++++++- rust/upb_kernel/BUILD | 1 + rust/upb_kernel/upb_api.c | 3 +- src/google/protobuf/compiler/rust/BUILD.bazel | 1 + .../rust/accessors/accessor_generator.h | 8 + .../compiler/rust/accessors/accessors.cc | 13 +- .../rust/accessors/repeated_scalar.cc | 156 ++++++++++++++++ src/google/protobuf/compiler/rust/naming.cc | 39 +++- 21 files changed, 831 insertions(+), 18 deletions(-) create mode 100644 rust/cpp_kernel/cpp_api.cc create mode 100644 rust/repeated.rs create mode 100644 src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc diff --git a/Cargo.bazel.lock b/Cargo.bazel.lock index 80674909ca..3839525cf1 100644 --- a/Cargo.bazel.lock +++ b/Cargo.bazel.lock @@ -1,5 +1,5 @@ { - "checksum": "8bc2d235f612e77f4dca1b6886cc8bd14df348168fea27a687805ed9518a8f1a", + "checksum": "641f887b045ff0fc19f64df79b53d96d77d1c03c96069036d84bd1104ddc0000", "crates": { "aho-corasick 1.1.2": { "name": "aho-corasick", @@ -108,6 +108,15 @@ "selects": {} }, "edition": "2018", + "proc_macro_deps": { + "common": [ + { + "id": "paste 1.0.14", + "target": "paste" + } + ], + "selects": {} + }, "version": "0.0.1" }, "license": null @@ -318,6 +327,59 @@ }, "license": "MIT OR Apache-2.0" }, + "paste 1.0.14": { + "name": "paste", + "version": "1.0.14", + "repository": { + "Http": { + "url": "https://crates.io/api/v1/crates/paste/1.0.14/download", + "sha256": "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + } + }, + "targets": [ + { + "ProcMacro": { + "crate_name": "paste", + "crate_root": "src/lib.rs", + "srcs": [ + "**/*.rs" + ] + } + }, + { + "BuildScript": { + "crate_name": "build_script_build", + "crate_root": "build.rs", + "srcs": [ + "**/*.rs" + ] + } + } + ], + "library_target_name": "paste", + "common_attrs": { + "compile_data_glob": [ + "**" + ], + "deps": { + "common": [ + { + "id": "paste 1.0.14", + "target": "build_script_build" + } + ], + "selects": {} + }, + "edition": "2018", + "version": "1.0.14" + }, + "build_script_attrs": { + "data_glob": [ + "**" + ] + }, + "license": "MIT OR Apache-2.0" + }, "proc-macro2 1.0.69": { "name": "proc-macro2", "version": "1.0.69", @@ -769,4 +831,3 @@ }, "conditions": {} } - diff --git a/Cargo.lock b/Cargo.lock index e075e9722c..ea70571b3c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -22,6 +22,7 @@ name = "direct-cargo-bazel-deps" version = "0.0.1" dependencies = [ "googletest", + "paste", ] [[package]] @@ -61,6 +62,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + [[package]] name = "proc-macro2" version = "1.0.69" @@ -130,4 +137,3 @@ name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - diff --git a/WORKSPACE b/WORKSPACE index 2c665be1a8..d0d66ec54f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -193,6 +193,9 @@ crates_repository( "googletest": crate.spec( version = ">0.0.0", ), + "paste": crate.spec( + version = ">=1", + ), }, ) diff --git a/rust/BUILD b/rust/BUILD index dd23b44abe..3e0ca71ed6 100644 --- a/rust/BUILD +++ b/rust/BUILD @@ -52,6 +52,7 @@ PROTOBUF_SHARED = [ "optional.rs", "primitive.rs", "proxied.rs", + "repeated.rs", "shared.rs", "string.rs", "vtable.rs", @@ -92,8 +93,14 @@ rust_library( name = "protobuf_cpp", srcs = PROTOBUF_SHARED + ["cpp.rs"], crate_root = "shared.rs", + proc_macro_deps = [ + "@crate_index//:paste", + ], rustc_flags = ["--cfg=cpp_kernel"], - deps = [":utf8"], + deps = [ + ":utf8", + "//rust/cpp_kernel:cpp_api", + ], ) rust_test( diff --git a/rust/cpp.rs b/rust/cpp.rs index c2af02d7c0..6ef124070e 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -7,7 +7,8 @@ // Rust Protobuf runtime using the C++ kernel. -use crate::__internal::{Private, RawArena, RawMessage}; +use crate::__internal::{Private, RawArena, RawMessage, RawRepeatedField}; +use paste::paste; use std::alloc::Layout; use std::cell::UnsafeCell; use std::fmt; @@ -35,6 +36,7 @@ pub struct Arena { impl Arena { /// Allocates a fresh arena. #[inline] + #[allow(clippy::new_without_default)] pub fn new() -> Self { Self { ptr: NonNull::dangling(), _not_sync: PhantomData } } @@ -182,6 +184,116 @@ pub fn copy_bytes_in_arena_if_needed_by_runtime<'a>( val } +/// RepeatedField impls delegate out to `extern "C"` functions exposed by +/// `cpp_api.h` and store either a RepeatedField* or a RepeatedPtrField* +/// depending on the type. +/// +/// Note: even though this type is `Copy`, it should only be copied by +/// protobuf internals that can maintain mutation invariants: +/// +/// - No concurrent mutation for any two fields in a message: this means +/// mutators cannot be `Send` but are `Sync`. +/// - If there are multiple accessible `Mut` to a single message at a time, they +/// must be different fields, and not be in the same oneof. As such, a `Mut` +/// cannot be `Clone` but *can* reborrow itself with `.as_mut()`, which +/// converts `&'b mut Mut<'a, T>` to `Mut<'b, T>`. +#[derive(Clone, Copy)] +pub struct RepeatedField<'msg, T: ?Sized> { + inner: RepeatedFieldInner<'msg>, + _phantom: PhantomData<&'msg mut T>, +} + +/// CPP runtime-specific arguments for initializing a RepeatedField. +/// See RepeatedField comment about mutation invariants for when this type can +/// be copied. +#[derive(Clone, Copy)] +pub struct RepeatedFieldInner<'msg> { + pub raw: RawRepeatedField, + pub _phantom: PhantomData<&'msg ()>, +} + +impl<'msg, T: ?Sized> RepeatedField<'msg, T> { + pub fn from_inner(_private: Private, inner: RepeatedFieldInner<'msg>) -> Self { + RepeatedField { inner, _phantom: PhantomData } + } +} +impl<'msg> RepeatedField<'msg, i32> {} + +pub trait RepeatedScalarOps { + fn new_repeated_field() -> RawRepeatedField; + fn push(f: RawRepeatedField, v: Self); + fn len(f: RawRepeatedField) -> usize; + fn get(f: RawRepeatedField, i: usize) -> Self; + fn set(f: RawRepeatedField, i: usize, v: Self); +} + +macro_rules! impl_repeated_scalar_ops { + ($($t: ty),*) => { + paste! { $( + extern "C" { + fn [< __pb_rust_RepeatedField_ $t _new >]() -> RawRepeatedField; + fn [< __pb_rust_RepeatedField_ $t _add >](f: RawRepeatedField, v: $t); + fn [< __pb_rust_RepeatedField_ $t _size >](f: RawRepeatedField) -> usize; + fn [< __pb_rust_RepeatedField_ $t _get >](f: RawRepeatedField, i: usize) -> $t; + fn [< __pb_rust_RepeatedField_ $t _set >](f: RawRepeatedField, i: usize, v: $t); + } + impl RepeatedScalarOps for $t { + fn new_repeated_field() -> RawRepeatedField { + unsafe { [< __pb_rust_RepeatedField_ $t _new >]() } + } + fn push(f: RawRepeatedField, v: Self) { + unsafe { [< __pb_rust_RepeatedField_ $t _add >](f, v) } + } + fn len(f: RawRepeatedField) -> usize { + unsafe { [< __pb_rust_RepeatedField_ $t _size >](f) } + } + fn get(f: RawRepeatedField, i: usize) -> Self { + unsafe { [< __pb_rust_RepeatedField_ $t _get >](f, i) } + } + fn set(f: RawRepeatedField, i: usize, v: Self) { + unsafe { [< __pb_rust_RepeatedField_ $t _set >](f, i, v) } + } + } + )* } + }; +} + +impl_repeated_scalar_ops!(i32, u32, i64, u64, f32, f64, bool); + +impl<'msg, T: RepeatedScalarOps> RepeatedField<'msg, T> { + #[allow(clippy::new_without_default, dead_code)] + /// new() is not currently used in our normal pathways, it is only used + /// for testing. Existing `RepeatedField<>`s are owned by, and retrieved + /// from, the containing `Message`. + pub fn new() -> Self { + Self::from_inner( + Private, + RepeatedFieldInner::<'msg> { raw: T::new_repeated_field(), _phantom: PhantomData }, + ) + } + pub fn push(&mut self, val: T) { + T::push(self.inner.raw, val) + } + pub fn len(&self) -> usize { + T::len(self.inner.raw) + } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + pub fn get(&self, index: usize) -> Option { + if index >= self.len() { + return None; + } + Some(T::get(self.inner.raw, index)) + } + pub fn set(&mut self, index: usize, val: T) { + if index >= self.len() { + return; + } + T::set(self.inner.raw, index, val) + } +} + #[cfg(test)] mod tests { use super::*; @@ -201,4 +313,27 @@ mod tests { let serialized_data = SerializedData { data: NonNull::new(ptr).unwrap(), len: len }; assert_eq!(&*serialized_data, b"Hello world"); } + + #[test] + fn repeated_field() { + let mut r = RepeatedField::::new(); + assert_eq!(r.len(), 0); + r.push(32); + assert_eq!(r.get(0), Some(32)); + + let mut r = RepeatedField::::new(); + assert_eq!(r.len(), 0); + r.push(32); + assert_eq!(r.get(0), Some(32)); + + let mut r = RepeatedField::::new(); + assert_eq!(r.len(), 0); + r.push(0.1234f64); + assert_eq!(r.get(0), Some(0.1234)); + + let mut r = RepeatedField::::new(); + assert_eq!(r.len(), 0); + r.push(true); + assert_eq!(r.get(0), Some(true)); + } } diff --git a/rust/cpp_kernel/BUILD b/rust/cpp_kernel/BUILD index 245772c834..d10f9e7db8 100644 --- a/rust/cpp_kernel/BUILD +++ b/rust/cpp_kernel/BUILD @@ -4,13 +4,14 @@ load("@rules_rust//rust:defs.bzl", "rust_library") cc_library( name = "cpp_api", + srcs = ["cpp_api.cc"], hdrs = ["cpp_api.h"], visibility = [ "//src/google/protobuf:__subpackages__", "//rust:__subpackages__", ], deps = [ - ":rust_alloc_for_cpp_api", + ":rust_alloc_for_cpp_api", # buildcleaner: keep "//:protobuf_nowkt", ], ) diff --git a/rust/cpp_kernel/cpp_api.cc b/rust/cpp_kernel/cpp_api.cc new file mode 100644 index 0000000000..8ff79d8fa9 --- /dev/null +++ b/rust/cpp_kernel/cpp_api.cc @@ -0,0 +1,35 @@ +#include "google/protobuf/repeated_field.h" + +extern "C" { + +#define expose_repeated_field_methods(ty, rust_ty) \ + google::protobuf::RepeatedField* __pb_rust_RepeatedField_##rust_ty##_new() { \ + return new google::protobuf::RepeatedField(); \ + } \ + void __pb_rust_RepeatedField_##rust_ty##_add(google::protobuf::RepeatedField* r, \ + ty val) { \ + r->Add(val); \ + } \ + size_t __pb_rust_RepeatedField_##rust_ty##_size( \ + google::protobuf::RepeatedField* r) { \ + return r->size(); \ + } \ + ty __pb_rust_RepeatedField_##rust_ty##_get(google::protobuf::RepeatedField* r, \ + size_t index) { \ + return r->Get(index); \ + } \ + void __pb_rust_RepeatedField_##rust_ty##_set(google::protobuf::RepeatedField* r, \ + size_t index, ty val) { \ + return r->Set(index, val); \ + } + +expose_repeated_field_methods(int32_t, i32); +expose_repeated_field_methods(uint32_t, u32); +expose_repeated_field_methods(float, f32); +expose_repeated_field_methods(double, f64); +expose_repeated_field_methods(bool, bool); +expose_repeated_field_methods(uint64_t, u64); +expose_repeated_field_methods(int64_t, i64); + +#undef expose_repeated_field_methods +} diff --git a/rust/internal.rs b/rust/internal.rs index 1e0f536b19..e56c9dc555 100644 --- a/rust/internal.rs +++ b/rust/internal.rs @@ -51,6 +51,17 @@ mod _opaque_pointees { _data: [u8; 0], _marker: std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>, } + + /// Opaque pointee for [`RawRepeatedField`] + /// + /// This type is not meant to be dereferenced in Rust code. + /// It is only meant to provide type safety for raw pointers + /// which are manipulated behind FFI. + #[repr(C)] + pub struct RawRepeatedFieldData { + _data: [u8; 0], + _marker: std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>, + } } /// A raw pointer to the underlying message for this runtime. @@ -59,6 +70,9 @@ pub type RawMessage = NonNull<_opaque_pointees::RawMessageData>; /// A raw pointer to the underlying arena for this runtime. pub type RawArena = NonNull<_opaque_pointees::RawArenaData>; +/// A raw pointer to the underlying repeated field container for this runtime. +pub type RawRepeatedField = NonNull<_opaque_pointees::RawRepeatedFieldData>; + /// Represents an ABI-stable version of `NonNull<[u8]>`/`string_view` (a /// borrowed slice of bytes) for FFI use only. /// diff --git a/rust/repeated.rs b/rust/repeated.rs new file mode 100644 index 0000000000..b824de8399 --- /dev/null +++ b/rust/repeated.rs @@ -0,0 +1,119 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2023 Google LLC. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +/// Repeated scalar fields are implemented around the runtime-specific +/// `RepeatedField` struct. `RepeatedField` stores an opaque pointer to the +/// runtime-specific representation of a repeated scalar (`upb_Array*` on upb, +/// and `RepeatedField*` on cpp). +use std::marker::PhantomData; + +use crate::{ + __internal::{Private, RawRepeatedField}, + __runtime::{RepeatedField, RepeatedFieldInner}, +}; + +#[derive(Clone, Copy)] +pub struct RepeatedFieldRef<'a> { + pub repeated_field: RawRepeatedField, + pub _phantom: PhantomData<&'a mut ()>, +} + +unsafe impl<'a> Send for RepeatedFieldRef<'a> {} +unsafe impl<'a> Sync for RepeatedFieldRef<'a> {} + +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct RepeatedView<'a, T: ?Sized> { + inner: RepeatedField<'a, T>, +} + +impl<'msg, T: ?Sized> RepeatedView<'msg, T> { + pub fn from_inner(_private: Private, inner: RepeatedFieldInner<'msg>) -> Self { + Self { inner: RepeatedField::<'msg>::from_inner(_private, inner) } + } +} + +pub struct RepeatedFieldIter<'a, T> { + inner: RepeatedField<'a, T>, + current_index: usize, +} + +impl<'a, T> std::fmt::Debug for RepeatedView<'a, T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("RepeatedView").finish() + } +} + +#[repr(transparent)] +pub struct RepeatedMut<'a, T: ?Sized> { + inner: RepeatedField<'a, T>, +} + +impl<'msg, T: ?Sized> RepeatedMut<'msg, T> { + pub fn from_inner(_private: Private, inner: RepeatedFieldInner<'msg>) -> Self { + Self { inner: RepeatedField::from_inner(_private, inner) } + } +} + +impl<'a, T> std::ops::Deref for RepeatedMut<'a, T> { + type Target = RepeatedView<'a, T>; + fn deref(&self) -> &Self::Target { + // SAFETY: + // - `Repeated{View,Mut}<'a, T>` are both `#[repr(transparent)]` over + // `RepeatedField<'a, T>`. + // - `RepeatedField` is a type alias for `NonNull`. + unsafe { &*(self as *const Self as *const RepeatedView<'a, T>) } + } +} + +macro_rules! impl_repeated_primitives { + ($($t:ty),*) => { + $( + impl<'a> RepeatedView<'a, $t> { + pub fn len(&self) -> usize { + self.inner.len() + } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + pub fn get(&self, index: usize) -> Option<$t> { + self.inner.get(index) + } + } + + impl<'a> RepeatedMut<'a, $t> { + pub fn push(&mut self, val: $t) { + self.inner.push(val) + } + pub fn set(&mut self, index: usize, val: $t) { + self.inner.set(index, val) + } + } + + impl<'a> std::iter::Iterator for RepeatedFieldIter<'a, $t> { + type Item = $t; + fn next(&mut self) -> Option { + let val = self.inner.get(self.current_index); + if val.is_some() { + self.current_index += 1; + } + val + } + } + + impl<'a> std::iter::IntoIterator for RepeatedView<'a, $t> { + type Item = $t; + type IntoIter = RepeatedFieldIter<'a, $t>; + fn into_iter(self) -> Self::IntoIter { + RepeatedFieldIter { inner: self.inner, current_index: 0 } + } + } + )* + } +} + +impl_repeated_primitives!(i32, u32, bool, f32, f64, i64, u64); diff --git a/rust/shared.rs b/rust/shared.rs index 3c4408d943..f8a9d117d9 100644 --- a/rust/shared.rs +++ b/rust/shared.rs @@ -22,6 +22,7 @@ pub mod __public { pub use crate::proxied::{ Mut, MutProxy, Proxied, ProxiedWithPresence, SettableValue, View, ViewProxy, }; + pub use crate::repeated::{RepeatedFieldRef, RepeatedMut, RepeatedView}; pub use crate::string::{BytesMut, ProtoStr, ProtoStrMut}; } pub use __public::*; @@ -46,6 +47,7 @@ mod macros; mod optional; mod primitive; mod proxied; +mod repeated; mod string; mod vtable; diff --git a/rust/test/cpp/interop/test_utils.cc b/rust/test/cpp/interop/test_utils.cc index d5c3784df1..5c27a95c9b 100644 --- a/rust/test/cpp/interop/test_utils.cc +++ b/rust/test/cpp/interop/test_utils.cc @@ -8,7 +8,7 @@ #include #include "absl/strings/string_view.h" -#include "google/protobuf/rust/cpp_kernel/cpp_api.h" +#include "rust/cpp_kernel/cpp_api.h" #include "google/protobuf/unittest.pb.h" extern "C" void MutateTestAllTypes(protobuf_unittest::TestAllTypes* msg) { diff --git a/rust/test/shared/BUILD b/rust/test/shared/BUILD index e5be7b7511..79d43ddd4e 100644 --- a/rust/test/shared/BUILD +++ b/rust/test/shared/BUILD @@ -151,6 +151,9 @@ rust_test( "//rust:protobuf_cpp": "protobuf", "//rust/test/shared:matchers_cpp": "matchers", }, + proc_macro_deps = [ + "@crate_index//:paste", + ], tags = [ # TODO: Enable testing on arm once we support sanitizers for Rust on Arm. "not_build:arm", @@ -170,6 +173,9 @@ rust_test( "//rust:protobuf_upb": "protobuf", "//rust/test/shared:matchers_upb": "matchers", }, + proc_macro_deps = [ + "@crate_index//:paste", + ], tags = [ # TODO: Enable testing on arm once we support sanitizers for Rust on Arm. "not_build:arm", diff --git a/rust/test/shared/accessors_test.rs b/rust/test/shared/accessors_test.rs index 40098b18ef..910f9f133e 100644 --- a/rust/test/shared/accessors_test.rs +++ b/rust/test/shared/accessors_test.rs @@ -9,6 +9,7 @@ use googletest::prelude::*; use matchers::{is_set, is_unset}; +use paste::paste; use protobuf::Optional; use unittest_proto::proto2_unittest::{TestAllTypes, TestAllTypes_}; @@ -398,3 +399,51 @@ fn test_oneof_accessors() { // This should show it set to the OneofBytes but its not supported yet. assert_that!(msg.oneof_field(), matches_pattern!(not_set(_))); } + +macro_rules! generate_repeated_numeric_test { + ($(($t: ty, $field: ident)),*) => { + paste! { $( + #[test] + fn [< test_repeated_ $field _accessors >]() { + let mut msg = TestAllTypes::new(); + assert_that!(msg.[< repeated_ $field >]().len(), eq(0)); + assert_that!(msg.[]().get(0), none()); + + let mut mutator = msg.[](); + mutator.push(1 as $t); + assert_that!(mutator.len(), eq(1)); + assert_that!(mutator.get(0), some(eq(1 as $t))); + mutator.set(0, 2 as $t); + assert_that!(mutator.get(0), some(eq(2 as $t))); + mutator.push(1 as $t); + + assert_that!(mutator.into_iter().collect::>(), eq(vec![2 as $t, 1 as $t])); + } + )* } + }; +} + +generate_repeated_numeric_test!( + (i32, int32), + (u32, uint32), + (i64, int64), + (u64, uint64), + (f32, float), + (f64, double) +); + +#[test] +fn test_repeated_bool_accessors() { + let mut msg = TestAllTypes::new(); + assert_that!(msg.repeated_bool().len(), eq(0)); + assert_that!(msg.repeated_bool().get(0), none()); + + let mut mutator = msg.repeated_bool_mut(); + mutator.push(true); + assert_that!(mutator.len(), eq(1)); + assert_that!(mutator.get(0), some(eq(true))); + mutator.set(0, false); + assert_that!(mutator.get(0), some(eq(false))); + mutator.push(true); + assert_that!(mutator.into_iter().collect::>(), eq(vec![false, true])); +} diff --git a/rust/upb.rs b/rust/upb.rs index cd5cc77e74..57f9c5512b 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -7,7 +7,7 @@ //! UPB FFI wrapper code for use by Rust Protobuf. -use crate::__internal::{Private, RawArena, RawMessage}; +use crate::__internal::{Private, PtrAndLen, RawArena, RawMessage, RawRepeatedField}; use std::alloc; use std::alloc::Layout; use std::cell::UnsafeCell; @@ -284,6 +284,149 @@ pub fn copy_bytes_in_arena_if_needed_by_runtime<'a>( } } +/// RepeatedFieldInner contains a `upb_Array*` as well as a reference to an +/// `Arena`, most likely that of the containing `Message`. upb requires an Arena +/// to perform mutations on a repeated field. +#[derive(Clone, Copy, Debug)] +pub struct RepeatedFieldInner<'msg> { + pub raw: RawRepeatedField, + pub arena: &'msg Arena, +} + +#[derive(Clone, Copy, Debug)] +pub struct RepeatedField<'msg, T: ?Sized> { + inner: RepeatedFieldInner<'msg>, + _phantom: PhantomData<&'msg mut T>, +} + +impl<'msg, T: ?Sized> RepeatedField<'msg, T> { + pub fn len(&self) -> usize { + unsafe { upb_Array_Size(self.inner.raw) } + } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + pub fn from_inner(_private: Private, inner: RepeatedFieldInner<'msg>) -> Self { + Self { inner, _phantom: PhantomData } + } +} + +// Transcribed from google3/third_party/upb/upb/message/value.h +#[repr(C)] +#[derive(Clone, Copy)] +union upb_MessageValue { + bool_val: bool, + float_val: std::ffi::c_float, + double_val: std::ffi::c_double, + uint32_val: u32, + int32_val: i32, + uint64_val: u64, + int64_val: i64, + array_val: *const std::ffi::c_void, + map_val: *const std::ffi::c_void, + msg_val: *const std::ffi::c_void, + str_val: PtrAndLen, +} + +// Transcribed from google3/third_party/upb/upb/base/descriptor_constants.h +#[repr(C)] +#[allow(dead_code)] +enum UpbCType { + Bool = 1, + Float = 2, + Int32 = 3, + UInt32 = 4, + Enum = 5, + Message = 6, + Double = 7, + Int64 = 8, + UInt64 = 9, + String = 10, + Bytes = 11, +} + +extern "C" { + #[allow(dead_code)] + fn upb_Array_New(a: RawArena, r#type: std::ffi::c_int) -> RawRepeatedField; + fn upb_Array_Size(arr: RawRepeatedField) -> usize; + fn upb_Array_Set(arr: RawRepeatedField, i: usize, val: upb_MessageValue); + fn upb_Array_Get(arr: RawRepeatedField, i: usize) -> upb_MessageValue; + fn upb_Array_Append(arr: RawRepeatedField, val: upb_MessageValue, arena: RawArena); +} + +macro_rules! impl_repeated_primitives { + ($(($rs_type:ty, $union_field:ident, $upb_tag:expr)),*) => { + $( + impl<'msg> RepeatedField<'msg, $rs_type> { + #[allow(dead_code)] + fn new(arena: &'msg Arena) -> Self { + Self { + inner: RepeatedFieldInner { + raw: unsafe { upb_Array_New(arena.raw, $upb_tag as std::ffi::c_int) }, + arena, + }, + _phantom: PhantomData, + } + } + pub fn push(&mut self, val: $rs_type) { + unsafe { upb_Array_Append( + self.inner.raw, + upb_MessageValue { $union_field: val }, + self.inner.arena.raw(), + ) } + } + pub fn get(&self, i: usize) -> Option<$rs_type> { + if i >= self.len() { + None + } else { + unsafe { Some(upb_Array_Get(self.inner.raw, i).$union_field) } + } + } + pub fn set(&self, i: usize, val: $rs_type) { + if i >= self.len() { + return; + } + unsafe { upb_Array_Set( + self.inner.raw, + i, + upb_MessageValue { $union_field: val }, + ) } + } + } + )* + } +} + +impl_repeated_primitives!( + (bool, bool_val, UpbCType::Bool), + (f32, float_val, UpbCType::Float), + (f64, double_val, UpbCType::Double), + (i32, int32_val, UpbCType::Int32), + (u32, uint32_val, UpbCType::UInt32), + (i64, int64_val, UpbCType::Int64), + (u64, uint64_val, UpbCType::UInt64) +); + +/// Returns a static thread-local empty RepeatedFieldInner for use in a +/// RepeatedView. +/// +/// # Safety +/// TODO: Split RepeatedFieldInner into mut and const variants to +/// enforce safety. The returned array must never be mutated. +pub unsafe fn empty_array() -> RepeatedFieldInner<'static> { + // TODO: Consider creating empty array in C. + fn new_repeated_field_inner() -> RepeatedFieldInner<'static> { + let arena = Box::leak::<'static>(Box::new(Arena::new())); + // Provide `i32` as a placeholder type. + RepeatedField::<'static, i32>::new(arena).inner + } + thread_local! { + static REPEATED_FIELD: RepeatedFieldInner<'static> = new_repeated_field_inner(); + } + + REPEATED_FIELD.with(|inner| *inner) +} + #[cfg(test)] mod tests { use super::*; @@ -309,4 +452,35 @@ mod tests { }; assert_eq!(&*serialized_data, b"Hello world"); } + + #[test] + fn i32_array() { + let mut arena = Arena::new(); + let mut arr = RepeatedField::::new(&arena); + assert_eq!(arr.len(), 0); + arr.push(1); + assert_eq!(arr.get(0), Some(1)); + assert_eq!(arr.len(), 1); + arr.set(0, 3); + assert_eq!(arr.get(0), Some(3)); + for i in 0..2048 { + arr.push(i); + assert_eq!(arr.get(arr.len() - 1), Some(i)); + } + } + #[test] + fn u32_array() { + let mut arena = Arena::new(); + let mut arr = RepeatedField::::new(&mut arena); + assert_eq!(arr.len(), 0); + arr.push(1); + assert_eq!(arr.get(0), Some(1)); + assert_eq!(arr.len(), 1); + arr.set(0, 3); + assert_eq!(arr.get(0), Some(3)); + for i in 0..2048 { + arr.push(i); + assert_eq!(arr.get(arr.len() - 1), Some(i)); + } + } } diff --git a/rust/upb_kernel/BUILD b/rust/upb_kernel/BUILD index dc65231433..b06f182117 100644 --- a/rust/upb_kernel/BUILD +++ b/rust/upb_kernel/BUILD @@ -8,6 +8,7 @@ cc_library( "//rust:__subpackages__", ], deps = [ + "//upb:collections", "//upb:mem", ], ) diff --git a/rust/upb_kernel/upb_api.c b/rust/upb_kernel/upb_api.c index 985749da23..a30b4dd9bb 100644 --- a/rust/upb_kernel/upb_api.c +++ b/rust/upb_kernel/upb_api.c @@ -8,4 +8,5 @@ #define UPB_BUILD_API -#include "upb/mem/arena.h" // IWYU pragma: keep +#include "upb/collections/array.h" // IWYU pragma: keep +#include "upb/mem/arena.h" // IWYU pragma: keep diff --git a/src/google/protobuf/compiler/rust/BUILD.bazel b/src/google/protobuf/compiler/rust/BUILD.bazel index f404beae97..1429b9bff1 100644 --- a/src/google/protobuf/compiler/rust/BUILD.bazel +++ b/src/google/protobuf/compiler/rust/BUILD.bazel @@ -51,6 +51,7 @@ cc_library( name = "accessors", srcs = [ "accessors/accessors.cc", + "accessors/repeated_scalar.cc", "accessors/singular_message.cc", "accessors/singular_scalar.cc", "accessors/singular_string.cc", diff --git a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h index e3a453493c..3bf1dca06d 100644 --- a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h +++ b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h @@ -86,6 +86,14 @@ class SingularMessage final : public AccessorGenerator { void InThunkCc(Context field) const override; }; +class RepeatedScalar final : public AccessorGenerator { + public: + ~RepeatedScalar() override = default; + void InMsgImpl(Context field) const override; + void InExternC(Context field) const override; + void InThunkCc(Context field) const override; +}; + class UnsupportedField final : public AccessorGenerator { public: ~UnsupportedField() override = default; diff --git a/src/google/protobuf/compiler/rust/accessors/accessors.cc b/src/google/protobuf/compiler/rust/accessors/accessors.cc index 82b38cc642..fa5c876bf9 100644 --- a/src/google/protobuf/compiler/rust/accessors/accessors.cc +++ b/src/google/protobuf/compiler/rust/accessors/accessors.cc @@ -29,10 +29,6 @@ std::unique_ptr AccessorGeneratorFor( return std::make_unique(); } - if (desc.is_repeated()) { - return std::make_unique(); - } - switch (desc.type()) { case FieldDescriptor::TYPE_INT32: case FieldDescriptor::TYPE_INT64: @@ -47,11 +43,20 @@ std::unique_ptr AccessorGeneratorFor( case FieldDescriptor::TYPE_FLOAT: case FieldDescriptor::TYPE_DOUBLE: case FieldDescriptor::TYPE_BOOL: + if (desc.is_repeated()) { + return std::make_unique(); + } return std::make_unique(); case FieldDescriptor::TYPE_BYTES: case FieldDescriptor::TYPE_STRING: + if (desc.is_repeated()) { + return std::make_unique(); + } return std::make_unique(); case FieldDescriptor::TYPE_MESSAGE: + if (desc.is_repeated()) { + return std::make_unique(); + } return std::make_unique(); default: diff --git a/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc b/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc new file mode 100644 index 0000000000..8f0a7625b3 --- /dev/null +++ b/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc @@ -0,0 +1,156 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2023 Google LLC. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +#include "absl/strings/string_view.h" +#include "google/protobuf/compiler/cpp/helpers.h" +#include "google/protobuf/compiler/rust/accessors/accessor_generator.h" +#include "google/protobuf/compiler/rust/context.h" +#include "google/protobuf/compiler/rust/naming.h" +#include "google/protobuf/descriptor.h" + +namespace google { +namespace protobuf { +namespace compiler { +namespace rust { + +void RepeatedScalar::InMsgImpl(Context field) const { + field.Emit({{"field", field.desc().name()}, + {"Scalar", PrimitiveRsTypeName(field.desc())}, + {"getter_thunk", Thunk(field, "get")}, + {"getter_mut_thunk", Thunk(field, "get_mut")}, + {"getter", + [&] { + if (field.is_upb()) { + field.Emit({}, R"rs( + pub fn r#$field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> { + let inner = unsafe { + $getter_thunk$( + self.inner.msg, + /* optional size pointer */ std::ptr::null(), + ) } + .map_or_else(|| unsafe {$pbr$::empty_array()}, |raw| { + $pbr$::RepeatedFieldInner{ raw, arena: &self.inner.arena } + }); + $pb$::RepeatedView::from_inner($pbi$::Private, inner) + } + )rs"); + } else { + field.Emit({}, R"rs( + pub fn r#$field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> { + $pb$::RepeatedView::from_inner( + $pbi$::Private, + $pbr$::RepeatedFieldInner{ + raw: unsafe { $getter_thunk$(self.inner.msg) }, + _phantom: std::marker::PhantomData, + }, + ) + } + )rs"); + } + }}, + {"clearer_thunk", Thunk(field, "clear")}, + {"field_mutator_getter", + [&] { + if (field.is_upb()) { + field.Emit({}, R"rs( + pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> { + $pb$::RepeatedMut::from_inner( + $pbi$::Private, + $pbr$::RepeatedFieldInner{ + raw: unsafe { $getter_mut_thunk$( + self.inner.msg, + /* optional size pointer */ std::ptr::null(), + self.inner.arena.raw(), + ) }, + arena: &self.inner.arena, + }, + ) + } + )rs"); + } else { + field.Emit({}, R"rs( + pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> { + $pb$::RepeatedMut::from_inner( + $pbi$::Private, + $pbr$::RepeatedFieldInner{ + raw: unsafe { $getter_mut_thunk$(self.inner.msg)}, + _phantom: std::marker::PhantomData, + }, + ) + } + )rs"); + } + }}}, + R"rs( + $getter$ + $field_mutator_getter$ + )rs"); +} + +void RepeatedScalar::InExternC(Context field) const { + field.Emit({{"Scalar", PrimitiveRsTypeName(field.desc())}, + {"getter_thunk", Thunk(field, "get")}, + {"getter_mut_thunk", Thunk(field, "get_mut")}, + {"getter", + [&] { + if (field.is_upb()) { + field.Emit(R"rs( + fn $getter_mut_thunk$( + raw_msg: $pbi$::RawMessage, + size: *const usize, + arena: $pbi$::RawArena, + ) -> $pbi$::RawRepeatedField; + // Returns `None` when returned array pointer is NULL. + fn $getter_thunk$( + raw_msg: $pbi$::RawMessage, + size: *const usize, + ) -> Option<$pbi$::RawRepeatedField>; + )rs"); + } else { + field.Emit(R"rs( + fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawRepeatedField; + fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawRepeatedField; + )rs"); + } + }}, + {"clearer_thunk", Thunk(field, "clear")}}, + R"rs( + fn $clearer_thunk$(raw_msg: $pbi$::RawMessage); + $getter$ + )rs"); +} + +void RepeatedScalar::InThunkCc(Context field) const { + field.Emit({{"field", cpp::FieldName(&field.desc())}, + {"Scalar", cpp::PrimitiveTypeName(field.desc().cpp_type())}, + {"QualifiedMsg", + cpp::QualifiedClassName(field.desc().containing_type())}, + {"clearer_thunk", Thunk(field, "clear")}, + {"getter_thunk", Thunk(field, "get")}, + {"getter_mut_thunk", Thunk(field, "get_mut")}, + {"impls", + [&] { + field.Emit( + R"cc( + void $clearer_thunk$($QualifiedMsg$* msg) { + msg->clear_$field$(); + } + google::protobuf::RepeatedField<$Scalar$>* $getter_mut_thunk$($QualifiedMsg$* msg) { + return msg->mutable_$field$(); + } + const google::protobuf::RepeatedField<$Scalar$>& $getter_thunk$($QualifiedMsg$& msg) { + return msg.$field$(); + } + )cc"); + }}}, + "$impls$"); +} + +} // namespace rust +} // namespace compiler +} // namespace protobuf +} // namespace google diff --git a/src/google/protobuf/compiler/rust/naming.cc b/src/google/protobuf/compiler/rust/naming.cc index eb0b30b905..bb1c35a3e2 100644 --- a/src/google/protobuf/compiler/rust/naming.cc +++ b/src/google/protobuf/compiler/rust/naming.cc @@ -64,22 +64,28 @@ std::string GetHeaderFile(Context file) { namespace { template -std::string Thunk(Context field, absl::string_view op) { +std::string FieldPrefix(Context field) { // NOTE: When field.is_upb(), this functions outputs must match the symbols // that the upbc plugin generates exactly. Failure to do so correctly results // in a link-time failure. absl::string_view prefix = field.is_cpp() ? "__rust_proto_thunk__" : ""; - std::string thunk = + std::string thunk_prefix = absl::StrCat(prefix, GetUnderscoreDelimitedFullName( field.WithDesc(field.desc().containing_type()))); + return thunk_prefix; +} + +template +std::string Thunk(Context field, absl::string_view op) { + std::string thunk = FieldPrefix(field); absl::string_view format; if (field.is_upb() && op == "get") { // upb getter is simply the field name (no "get" in the name). format = "_$1"; - } else if (field.is_upb() && op == "case") { - // upb oneof case function is x_case compared to has/set/clear which are in - // the other order e.g. clear_x. + } else if (field.is_upb() && (op == "case")) { + // some upb functions are in the order x_op compared to has/set/clear which + // are in the other order e.g. op_x. format = "_$1_$0"; } else { format = "_$0_$1"; @@ -89,9 +95,32 @@ std::string Thunk(Context field, absl::string_view op) { return thunk; } +std::string ThunkRepeated(Context field, + absl::string_view op) { + if (!field.is_upb()) { + return Thunk(field, op); + } + + std::string thunk = absl::StrCat("_", FieldPrefix(field)); + absl::string_view format; + if (op == "get") { + format = "_$1_upb_array"; + } else if (op == "get_mut") { + format = "_$1_mutable_upb_array"; + } else { + return Thunk(field, op); + } + + absl::SubstituteAndAppend(&thunk, format, op, field.desc().name()); + return thunk; +} + } // namespace std::string Thunk(Context field, absl::string_view op) { + if (field.desc().is_repeated()) { + return ThunkRepeated(field, op); + } return Thunk(field, op); }