Implement `IntoProxied` for repeated field setters

We modify set_<repeated_field> to accept the IntoProxied type as the value and move the value (avoid copying) whenever possible.

For UPB:
 - We fuse the arena of Repeated<T> with the parent message arena.
 - We use upb_Message_SetBaseField to set the upb_Array contained in the Repeated<T>.

For C++:
 - We generate an additional setter thunk that moves the value.
 - The move assignment operator of RepeatedField/RepeatedPtrField is specialized. In order to adhere to the layering check we need to add '#include' statements for all .proto imports to the generated thunks.pb.cc.

PiperOrigin-RevId: 631010333
pull/16686/head
Jakob Buchgraber 2024-05-06 05:16:05 -07:00 committed by Copybara-Service
parent 2f6e705595
commit b6e0a48b02
12 changed files with 166 additions and 20 deletions

View File

@ -324,13 +324,17 @@ def _rust_proto_aspect_common(target, ctx, is_upb):
if is_upb:
thunks_cc_info = target[UpbWrappedCcInfo].cc_info_with_thunks
else:
dep_cc_infos = []
for dep in proto_deps:
dep_cc_infos.append(dep[CcInfo])
thunks_cc_info = cc_common.merge_cc_infos(cc_infos = [_compile_cc(
feature_configuration = feature_configuration,
src = thunk,
ctx = ctx,
attr = attr,
cc_toolchain = cc_toolchain,
cc_infos = [target[CcInfo], ctx.attr._cpp_thunks_deps[CcInfo]],
cc_infos = [target[CcInfo], ctx.attr._cpp_thunks_deps[CcInfo]] + dep_cc_infos,
) for thunk in thunks])
runtime = proto_lang_toolchain.runtime

View File

@ -301,6 +301,10 @@ impl InnerRepeated {
pub fn as_mut(&mut self) -> InnerRepeatedMut<'_> {
InnerRepeatedMut::new(Private, self.raw)
}
pub fn raw(&self) -> RawRepeatedField {
self.raw
}
}
/// The raw type-erased pointer version of `RepeatedMut`.

View File

@ -15,7 +15,7 @@ use std::iter::FusedIterator;
use std::marker::PhantomData;
use crate::{
Mut, MutProxied, MutProxy, Proxied, View, ViewProxy,
IntoProxied, Mut, MutProxied, MutProxy, Proxied, View, ViewProxy,
__internal::Private,
__runtime::{InnerRepeated, InnerRepeatedMut, RawRepeatedField},
};
@ -208,6 +208,49 @@ where
}
}
impl<T> Repeated<T>
where
T: ?Sized + ProxiedInRepeated,
{
pub fn as_view(&self) -> View<Repeated<T>> {
RepeatedView { raw: self.inner.raw(), _phantom: PhantomData }
}
#[doc(hidden)]
pub fn inner(&self, _private: Private) -> &InnerRepeated {
&self.inner
}
}
impl<T> IntoProxied<Repeated<T>> for Repeated<T>
where
T: ?Sized + ProxiedInRepeated,
{
fn into(self, _private: Private) -> Repeated<T> {
self
}
}
impl<'msg, T> IntoProxied<Repeated<T>> for RepeatedView<'msg, T>
where
T: 'msg + ?Sized + ProxiedInRepeated,
{
fn into(self, _private: Private) -> Repeated<T> {
let mut repeated: Repeated<T> = Repeated::new();
T::repeated_copy_from(self, repeated.as_mut());
repeated
}
}
impl<'msg, T> IntoProxied<Repeated<T>> for RepeatedMut<'msg, T>
where
T: 'msg + ?Sized + ProxiedInRepeated,
{
fn into(self, _private: Private) -> Repeated<T> {
IntoProxied::into(self.as_view(), _private)
}
}
/// Types that can appear in a `Repeated<T>`.
///
/// This trait is implemented by generated code to communicate how the proxied
@ -275,7 +318,7 @@ impl<'msg, T: ?Sized> Debug for RepeatedIter<'msg, T> {
/// Users will generally write [`View<Repeated<T>>`](RepeatedView) or
/// [`Mut<Repeated<T>>`](RepeatedMut) to access the repeated elements
pub struct Repeated<T: ?Sized + ProxiedInRepeated> {
inner: InnerRepeated,
pub(crate) inner: InnerRepeated,
_phantom: PhantomData<T>,
}

View File

@ -155,6 +155,14 @@ impl InnerRepeated {
pub fn as_mut(&mut self) -> InnerRepeatedMut<'_> {
InnerRepeatedMut::new(Private, self.raw, &self.arena)
}
pub fn raw(&self) -> RawRepeatedField {
self.raw
}
pub fn arena(&self) -> &Arena {
&self.arena
}
}
/// The raw type-erased pointer version of `RepeatedMut`.

View File

@ -46,5 +46,6 @@ cc_library(
"//upb:mem",
"//upb:message",
"//upb:message_copy",
"//upb/mini_table",
],
)

View File

@ -22,14 +22,18 @@ pub use map::{
mod message;
pub use message::{
upb_Message, upb_Message_DeepClone, upb_Message_DeepCopy, upb_Message_New, RawMessage,
upb_Message, upb_Message_DeepClone, upb_Message_DeepCopy, upb_Message_New,
upb_Message_SetBaseField, RawMessage,
};
mod message_value;
pub use message_value::{upb_MessageValue, upb_MutableMessageValue};
mod mini_table;
pub use mini_table::{upb_MiniTable, RawMiniTable};
pub use mini_table::{
upb_MiniTable, upb_MiniTableField, upb_MiniTable_FindFieldByNumber, RawMiniTable,
RawMiniTableField,
};
mod opaque_pointee;

View File

@ -1,5 +1,5 @@
use crate::opaque_pointee::opaque_pointee;
use crate::{upb_MiniTable, RawArena};
use crate::{upb_MiniTable, upb_MiniTableField, RawArena};
use std::ptr::NonNull;
opaque_pointee!(upb_Message);
@ -22,4 +22,10 @@ extern "C" {
mini_table: *const upb_MiniTable,
arena: RawArena,
) -> Option<RawMessage>;
pub fn upb_Message_SetBaseField(
m: RawMessage,
mini_table: *const upb_MiniTableField,
val: *const std::ffi::c_void,
);
}

View File

@ -3,3 +3,13 @@ use std::ptr::NonNull;
opaque_pointee!(upb_MiniTable);
pub type RawMiniTable = NonNull<upb_MiniTable>;
opaque_pointee!(upb_MiniTableField);
pub type RawMiniTableField = NonNull<upb_MiniTableField>;
extern "C" {
pub fn upb_MiniTable_FindFieldByNumber(
m: *const upb_MiniTable,
number: u32,
) -> *const upb_MiniTableField;
}

View File

@ -10,9 +10,11 @@
#define UPB_BUILD_API
#include "upb/mem/arena.h" // IWYU pragma: keep
#include "upb/message/array.h" // IWYU pragma: keep
#include "upb/message/copy.h" // IWYU pragma: keep
#include "upb/message/map.h" // IWYU pragma: keep
#include "upb/mem/arena.h" // IWYU pragma: keep
#include "upb/message/array.h" // IWYU pragma: keep
#include "upb/message/copy.h" // IWYU pragma: keep
#include "upb/message/map.h" // IWYU pragma: keep
#include "upb/message/accessors.h" // IWYU pragma: keep
#include "upb/mini_table/message.h" // IWYU pragma: keep
const size_t __rust_proto_kUpb_Map_Begin = kUpb_Map_Begin;

View File

@ -104,17 +104,47 @@ void RepeatedField::InMsgImpl(Context& ctx, const FieldDescriptor& field,
)rs");
}
}},
{"move_setter_thunk", ThunkName(ctx, field, "move_set")},
{"setter",
[&] {
if (accessor_case == AccessorCase::VIEW) {
return;
}
ctx.Emit({}, R"rs(
pub fn set_$raw_field_name$(&mut self, src: $pb$::RepeatedView<'_, $RsType$>) {
// TODO: Implement IntoProxied and avoid copying.
self.$field$_mut().copy_from(src);
}
)rs");
if (ctx.is_upb()) {
ctx.Emit({{"field_number", field.number()}}, R"rs(
pub fn set_$raw_field_name$(&mut self, src: impl $pb$::IntoProxied<$pb$::Repeated<$RsType$>>) {
let minitable_field = unsafe {
$pbr$::upb_MiniTable_FindFieldByNumber(
Self::raw_minitable($pbi$::Private),
$field_number$
)
};
let val = src.into($pbi$::Private);
let inner = val.inner($pbi$::Private);
self.arena().fuse(inner.arena());
unsafe {
let value_ptr: *const *const std::ffi::c_void =
&(inner.raw().as_ptr() as *const std::ffi::c_void);
$pbr$::upb_Message_SetBaseField(self.raw_msg(),
minitable_field,
value_ptr as *const std::ffi::c_void);
}
}
)rs");
} else {
ctx.Emit({}, R"rs(
pub fn set_$raw_field_name$(&mut self, src: impl $pb$::IntoProxied<$pb$::Repeated<$RsType$>>) {
// Prevent the memory from being deallocated. The setter
// transfers ownership of the memory to the parent message.
let val = std::mem::ManuallyDrop::new(src.into($pbi$::Private));
unsafe {
$move_setter_thunk$(self.raw_msg(),
val.inner($pbi$::Private).raw());
}
}
)rs");
}
}},
},
R"rs(
@ -128,6 +158,7 @@ void RepeatedField::InExternC(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"getter_thunk", ThunkName(ctx, field, "get")},
{"getter_mut_thunk", ThunkName(ctx, field, "get_mut")},
{"move_setter_thunk", ThunkName(ctx, field, "move_set")},
{"getter",
[&] {
if (ctx.is_upb()) {
@ -147,6 +178,7 @@ void RepeatedField::InExternC(Context& ctx,
ctx.Emit(R"rs(
fn $getter_mut_thunk$(raw_msg: $pbr$::RawMessage) -> $pbr$::RawRepeatedField;
fn $getter_thunk$(raw_msg: $pbr$::RawMessage) -> $pbr$::RawRepeatedField;
fn $move_setter_thunk$(raw_msg: $pbr$::RawMessage, value: $pbr$::RawRepeatedField);
)rs");
}
}},
@ -199,6 +231,7 @@ void RepeatedField::InThunkCc(Context& ctx,
{"getter_mut_thunk", ThunkName(ctx, field, "get_mut")},
{"repeated_copy_from_thunk",
ThunkName(ctx, field, "repeated_copy_from")},
{"move_setter_thunk", ThunkName(ctx, field, "move_set")},
{"impls",
[&] {
ctx.Emit(
@ -213,6 +246,11 @@ void RepeatedField::InThunkCc(Context& ctx,
const $QualifiedMsg$* msg) {
return &msg->$field$();
}
void $move_setter_thunk$(
$QualifiedMsg$* msg,
$ContainerType$<$ElementType$>* value) {
*msg->mutable_$field$() = std::move(*value);
}
)cc");
}}},
"$impls$");

View File

@ -202,11 +202,23 @@ bool RustGenerator::Generate(const FileDescriptor* file,
thunks_cc.reset(generator_context->Open(GetThunkCcFile(ctx, *file)));
thunks_printer = std::make_unique<io::Printer>(thunks_cc.get());
thunks_printer->Emit({{"proto_h", GetHeaderFile(ctx, *file)}},
R"cc(
thunks_printer->Emit(
{{"proto_h", GetHeaderFile(ctx, *file)},
{"proto_deps_h",
[&] {
for (int i = 0; i < file->dependency_count(); i++) {
thunks_printer->Emit(
{{"proto_dep_h", GetHeaderFile(ctx, *file->dependency(i))}},
R"cc(
#include "$proto_dep_h$"
)cc");
}
}}},
R"cc(
#include "$proto_h$"
$proto_deps_h$
#include "google/protobuf/rust/cpp_kernel/cpp_api.h"
)cc");
)cc");
}
for (int i = 0; i < file->message_type_count(); ++i) {

View File

@ -285,6 +285,16 @@ void IntoProxiedForMessage(Context& ctx, const Descriptor& msg) {
ABSL_LOG(FATAL) << "unreachable";
}
void MessageGetMinitable(Context& ctx, const Descriptor& msg) {
if (ctx.opts().kernel == Kernel::kUpb) {
ctx.Emit({{"minitable", UpbMinitableName(msg)}}, R"rs(
pub fn raw_minitable(_private: $pbi$::Private) -> *const $pbr$::upb_MiniTable {
unsafe { $std$::ptr::addr_of!($minitable$) }
}
)rs");
}
}
void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) {
switch (ctx.opts().kernel) {
case Kernel::kCpp:
@ -363,7 +373,6 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) {
}
}
}
)rs");
return;
case Kernel::kUpb:
@ -821,6 +830,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
}
}},
{"into_proxied_impl", [&] { IntoProxiedForMessage(ctx, msg); }},
{"get_upb_minitable", [&] { MessageGetMinitable(ctx, msg); }},
{"repeated_impl", [&] { MessageProxiedInRepeated(ctx, msg); }},
{"map_value_impl", [&] { MessageProxiedInMapValue(ctx, msg); }},
{"unwrap_upb",
@ -976,6 +986,8 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
$pb$::ViewProxy::as_view(self).serialize()
}
$get_upb_minitable$
$raw_arena_getter_for_msgmut$
$accessor_fns_for_muts$
@ -1043,6 +1055,8 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
$Msg$Mut::new($pbi$::Private, &mut self.inner)
}
$get_upb_minitable$
$accessor_fns$
} // impl $Msg$