From 734729afc2ace9dc33e41b91a34fd16bb53594d9 Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Wed, 24 Apr 2024 12:25:37 -0700 Subject: [PATCH] Create the concept of 'owned data' in upb/rust as a generalization of the upb.rs SerializedData (which is a arena + data for arbitrary types, both thin and wide ref types), use that for the wire parse/serialize path. PiperOrigin-RevId: 627814154 --- rust/upb.rs | 68 +----------- rust/upb/BUILD | 1 + rust/upb/arena.rs | 49 +++++++- rust/upb/lib.rs | 11 +- rust/upb/message.rs | 4 + rust/upb/owned_arena_box.rs | 111 +++++++++++++++++++ rust/upb/wire.rs | 66 ++++++++++- src/google/protobuf/compiler/rust/message.cc | 40 ++----- 8 files changed, 246 insertions(+), 104 deletions(-) create mode 100644 rust/upb/owned_arena_box.rs diff --git a/rust/upb.rs b/rust/upb.rs index 3b9a69ba44..e697a27bf8 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -14,9 +14,7 @@ use crate::{ }; use core::fmt::Debug; use std::alloc::Layout; -use std::fmt; use std::mem::{size_of, ManuallyDrop, MaybeUninit}; -use std::ops::Deref; use std::ptr::{self, NonNull}; use std::slice; use std::sync::OnceLock; @@ -60,55 +58,7 @@ impl ScratchSpace { } } -/// Serialized Protobuf wire format data. -/// -/// It's typically produced by `::serialize()`. -pub struct SerializedData { - data: NonNull, - len: usize, - - // The arena that owns `data`. - _arena: Arena, -} - -impl SerializedData { - /// Construct `SerializedData` from raw pointers and its owning arena. - /// - /// # Safety - /// - `arena` must be have allocated `data` - /// - `data` must be readable for `len` bytes and not mutate while this - /// struct exists - pub unsafe fn from_raw_parts(arena: Arena, data: NonNull, len: usize) -> Self { - SerializedData { _arena: arena, data, len } - } - - /// Gets a raw slice pointer. - pub fn as_ptr(&self) -> *const [u8] { - ptr::slice_from_raw_parts(self.data.as_ptr(), self.len) - } -} - -impl Deref for SerializedData { - type Target = [u8]; - fn deref(&self) -> &Self::Target { - // SAFETY: `data` is valid for `len` bytes as promised by - // the caller of `SerializedData::from_raw_parts`. - unsafe { slice::from_raw_parts(self.data.as_ptr(), self.len) } - } -} - -// TODO: remove after IntoProxied has been implemented for bytes. -impl AsRef<[u8]> for SerializedData { - fn as_ref(&self) -> &[u8] { - self - } -} - -impl fmt::Debug for SerializedData { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Debug::fmt(self.deref(), f) - } -} +pub type SerializedData = upb::OwnedArenaBox<[u8]>; // TODO: Investigate replacing this with direct access to UPB bits. pub type MessagePresentMutData<'msg, T> = crate::vtable::RawVTableOptionalMutatorData<'msg, T>; @@ -812,22 +762,6 @@ mod tests { use super::*; use googletest::prelude::*; - #[test] - fn test_serialized_data_roundtrip() { - let arena = Arena::new(); - let original_data = b"Hello world"; - let len = original_data.len(); - - let serialized_data = unsafe { - SerializedData::from_raw_parts( - arena, - NonNull::new(original_data as *const _ as *mut _).unwrap(), - len, - ) - }; - assert_that!(&*serialized_data, eq(b"Hello world")); - } - #[test] fn assert_c_type_sizes() { // TODO: add these same asserts in C++. diff --git a/rust/upb/BUILD b/rust/upb/BUILD index 894d4950e0..13cea4d7f5 100644 --- a/rust/upb/BUILD +++ b/rust/upb/BUILD @@ -23,6 +23,7 @@ rust_library( "message_value.rs", "mini_table.rs", "opaque_pointee.rs", + "owned_arena_box.rs", "string_view.rs", "wire.rs", ], diff --git a/rust/upb/arena.rs b/rust/upb/arena.rs index 8610d4de56..281ea2f83c 100644 --- a/rust/upb/arena.rs +++ b/rust/upb/arena.rs @@ -3,7 +3,7 @@ use std::alloc::{self, Layout}; use std::cell::UnsafeCell; use std::marker::PhantomData; use std::mem::{align_of, MaybeUninit}; -use std::ptr::NonNull; +use std::ptr::{self, NonNull}; use std::slice; opaque_pointee!(upb_Arena); @@ -95,6 +95,53 @@ impl Arena { // `UPB_MALLOC_ALIGN` boundary. unsafe { slice::from_raw_parts_mut(ptr.cast(), layout.size()) } } + + /// Same as alloc() but panics if `layout.align() > UPB_MALLOC_ALIGN`. + #[allow(clippy::mut_from_ref)] + #[inline] + pub fn checked_alloc(&self, layout: Layout) -> &mut [MaybeUninit] { + assert!(layout.align() <= UPB_MALLOC_ALIGN); + // SAFETY: layout.align() <= UPB_MALLOC_ALIGN asserted. + unsafe { self.alloc(layout) } + } + + /// Copies the T into this arena and returns a pointer to the T data inside + /// the arena. + pub fn copy_in<'a, T: Copy>(&'a self, data: &T) -> &'a T { + let layout = Layout::for_value(data); + let alloc = self.checked_alloc(layout); + + // SAFETY: + // - alloc is valid for `layout.len()` bytes and is the uninit bytes are written + // to not read from until written. + // - T is copy so copying the bytes of the value is sound. + unsafe { + let alloc = alloc.as_mut_ptr().cast::>(); + // let data = (data as *const T).cast::>(); + (*alloc).write(*data) + } + } + + pub fn copy_str_in<'a>(&'a self, s: &str) -> &'a str { + let copied_bytes = self.copy_slice_in(s.as_bytes()); + // SAFETY: `copied_bytes` has same contents as `s` and so must meet &str + // criteria. + unsafe { std::str::from_utf8_unchecked(copied_bytes) } + } + + pub fn copy_slice_in<'a, T: Copy>(&'a self, data: &[T]) -> &'a [T] { + let layout = Layout::for_value(data); + let alloc: *mut T = self.checked_alloc(layout).as_mut_ptr().cast(); + + // SAFETY: + // - uninit_alloc is valid for `layout.len()` bytes and is the uninit bytes are + // written to not read from until written. + // - T is copy so copying the bytes of the values is sound. + unsafe { + ptr::copy_nonoverlapping(data.as_ptr(), alloc, data.len()); + slice::from_raw_parts_mut(alloc, data.len()) + } + } } impl Default for Arena { diff --git a/rust/upb/lib.rs b/rust/upb/lib.rs index f557d1b024..040d98d736 100644 --- a/rust/upb/lib.rs +++ b/rust/upb/lib.rs @@ -21,7 +21,9 @@ pub use map::{ }; mod message; -pub use message::{upb_Message, upb_Message_DeepClone, upb_Message_DeepCopy, RawMessage}; +pub use message::{ + upb_Message, upb_Message_DeepClone, upb_Message_DeepCopy, upb_Message_New, RawMessage, +}; mod message_value; pub use message_value::{upb_MessageValue, upb_MutableMessageValue}; @@ -31,8 +33,11 @@ pub use mini_table::{upb_MiniTable, RawMiniTable}; mod opaque_pointee; +mod owned_arena_box; +pub use owned_arena_box::OwnedArenaBox; + mod string_view; pub use string_view::StringView; -mod wire; -pub use wire::{upb_Decode, upb_Encode, DecodeStatus, EncodeStatus}; +pub mod wire; +pub use wire::{upb_Decode, DecodeStatus, EncodeStatus}; diff --git a/rust/upb/message.rs b/rust/upb/message.rs index 2a9fe91e30..fd831d7b6d 100644 --- a/rust/upb/message.rs +++ b/rust/upb/message.rs @@ -6,6 +6,10 @@ opaque_pointee!(upb_Message); pub type RawMessage = NonNull; extern "C" { + /// SAFETY: No constraints. + pub fn upb_Message_New(mini_table: *const upb_MiniTable, arena: RawArena) + -> Option; + pub fn upb_Message_DeepCopy( dst: RawMessage, src: RawMessage, diff --git a/rust/upb/owned_arena_box.rs b/rust/upb/owned_arena_box.rs new file mode 100644 index 0000000000..9f10accc49 --- /dev/null +++ b/rust/upb/owned_arena_box.rs @@ -0,0 +1,111 @@ +use crate::Arena; +use std::fmt::{self, Debug}; +use std::ops::{Deref, DerefMut}; +use std::ptr::NonNull; + +/// An 'owned' T, similar to a Box where the T is data +/// held in a upb Arena. By holding the data pointer and a corresponding arena +/// together the data liveness is be maintained. +/// +/// This struct is conceptually self-referential, where `data` points at memory +/// inside `arena`. This avoids typical concerns of self-referential data +/// structures because `arena` modifications (other than drop) will never +/// invalidate `data`, and `data` and `arena` are both behind indirections which +/// avoids any concern with std::mem::swap. +pub struct OwnedArenaBox { + data: NonNull, + arena: Arena, +} + +impl OwnedArenaBox { + /// Construct `OwnedArenaBox` from raw pointers and its owning arena. + /// + /// # Safety + /// - `data` must satisfy the safety constraints of pointer::as_mut::<'a>() + /// where 'a is the passed arena's lifetime (`data` should be valid and + /// not mutated while this struct is live). + /// - `data` should be a pointer into a block from a previous allocation on + /// `arena`, or to another arena fused to it, or should be pointing at + /// 'static data (and if it is pointing at any struct like upb_Message, + /// all data transitively reachable should similarly be kept live by + /// `arena` or be 'static). + pub unsafe fn new(data: NonNull, arena: Arena) -> Self { + OwnedArenaBox { arena, data } + } + + pub fn data(&self) -> *const T { + self.data.as_ptr() + } + + pub fn into_parts(self) -> (NonNull, Arena) { + (self.data, self.arena) + } +} + +impl Deref for OwnedArenaBox { + type Target = T; + fn deref(&self) -> &Self::Target { + self.as_ref() + } +} + +impl DerefMut for OwnedArenaBox { + fn deref_mut(&mut self) -> &mut Self::Target { + self.as_mut() + } +} + +impl AsRef for OwnedArenaBox { + fn as_ref(&self) -> &T { + // SAFETY: + // - `data` is valid under the conditions set on ::new(). + unsafe { self.data.as_ref() } + } +} + +impl AsMut for OwnedArenaBox { + fn as_mut(&mut self) -> &mut T { + // SAFETY: + // - `data` is valid under the conditions set on ::new(). + unsafe { self.data.as_mut() } + } +} + +impl Debug for OwnedArenaBox { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_tuple("OwnedArenaBox").field(self.deref()).finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::str; + + #[test] + fn test_byte_slice_pointer_roundtrip() { + let arena = Arena::new(); + let original_data: &'static [u8] = b"Hello world"; + let owned_data = unsafe { OwnedArenaBox::new(original_data.into(), arena) }; + assert_eq!(&*owned_data, b"Hello world"); + } + + #[test] + fn test_alloc_str_roundtrip() { + let arena = Arena::new(); + let s: &str = "Hello"; + let arena_alloc_str: NonNull = arena.copy_str_in(s).into(); + let owned_data = unsafe { OwnedArenaBox::new(arena_alloc_str, arena) }; + assert_eq!(&*owned_data, s); + } + + #[test] + fn test_sized_type_roundtrip() { + let arena = Arena::new(); + let arena_alloc_u32: NonNull = arena.copy_in(&7u32).into(); + let mut owned_data = unsafe { OwnedArenaBox::new(arena_alloc_u32, arena) }; + assert_eq!(*owned_data, 7); + *owned_data = 8; + assert_eq!(*owned_data, 8); + } +} diff --git a/rust/upb/wire.rs b/rust/upb/wire.rs index 2b68145cc2..aa335cc020 100644 --- a/rust/upb/wire.rs +++ b/rust/upb/wire.rs @@ -1,8 +1,9 @@ -use crate::{upb_ExtensionRegistry, upb_MiniTable, RawArena, RawMessage}; +use crate::{upb_ExtensionRegistry, upb_MiniTable, Arena, OwnedArenaBox, RawArena, RawMessage}; +use std::ptr::NonNull; // LINT.IfChange(encode_status) #[repr(C)] -#[derive(PartialEq, Eq, Copy, Clone)] +#[derive(PartialEq, Eq, Copy, Clone, Debug)] pub enum EncodeStatus { Ok = 0, OutOfMemory = 1, @@ -13,7 +14,7 @@ pub enum EncodeStatus { // LINT.IfChange(decode_status) #[repr(C)] -#[derive(PartialEq, Eq, Copy, Clone)] +#[derive(PartialEq, Eq, Copy, Clone, Debug)] pub enum DecodeStatus { Ok = 0, Malformed = 1, @@ -25,7 +26,62 @@ pub enum DecodeStatus { } // LINT.ThenChange() +/// If Err, then EncodeStatus != Ok. +/// +/// SAFETY: +/// - `msg` must be associated with `mini_table`. +pub unsafe fn encode( + msg: RawMessage, + mini_table: *const upb_MiniTable, +) -> Result, EncodeStatus> { + let arena = Arena::new(); + let mut buf: *mut u8 = std::ptr::null_mut(); + let mut len = 0usize; + + // SAFETY: + // - `mini_table` is the one associated with `msg`. + // - `buf` and `buf_size` are legally writable. + let status = upb_Encode(msg, mini_table, 0, arena.raw(), &mut buf, &mut len); + + if status == EncodeStatus::Ok { + assert!(!buf.is_null()); // EncodeStatus Ok should never return NULL data, even for len=0. + // SAFETY: upb guarantees that `buf` is valid to read for `len`. + let slice = NonNull::new_unchecked(std::ptr::slice_from_raw_parts_mut(buf, len)); + Ok(OwnedArenaBox::new(slice, arena)) + } else { + Err(status) + } +} + +/// Decodes into the provided message (merge semantics). If Err, then +/// DecodeStatus != Ok. +/// +/// SAFETY: +/// - `msg` must be mutable. +/// - `msg` must be associated with `mini_table`. +pub unsafe fn decode( + buf: &[u8], + msg: RawMessage, + mini_table: *const upb_MiniTable, + arena: &Arena, +) -> Result<(), DecodeStatus> { + let len = buf.len(); + let buf = buf.as_ptr(); + // SAFETY: + // - `mini_table` is the one associated with `msg` + // - `buf` is legally readable for at least `buf_size` bytes. + // - `extreg` is null. + let status = upb_Decode(buf, len, msg, mini_table, std::ptr::null(), 0, arena.raw()); + match status { + DecodeStatus::Ok => Ok(()), + _ => Err(status), + } +} + extern "C" { + // SAFETY: + // - `mini_table` is the one associated with `msg` + // - `buf` and `buf_size` are legally writable. pub fn upb_Encode( msg: RawMessage, mini_table: *const upb_MiniTable, @@ -35,6 +91,10 @@ extern "C" { buf_size: *mut usize, ) -> EncodeStatus; + // SAFETY: + // - `mini_table` is the one associated with `msg` + // - `buf` is legally readable for at least `buf_size` bytes. + // - `extreg` is either null or points at a valid upb_ExtensionRegistry. pub fn upb_Decode( buf: *const u8, buf_size: usize, diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 5700ebb458..5ccf2541ac 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -68,35 +68,17 @@ void MessageSerialize(Context& ctx, const Descriptor& msg) { case Kernel::kUpb: ctx.Emit({{"minitable", UpbMinitableName(msg)}}, R"rs( - let arena = $pbr$::Arena::new(); // SAFETY: $minitable$ is a static of a const object. let mini_table = unsafe { $std$::ptr::addr_of!($minitable$) }; - let options = 0; - let mut buf: *mut u8 = std::ptr::null_mut(); - let mut len = 0; - - // SAFETY: `mini_table` is the corresponding one that was used to - // construct `self.raw_msg()`. - let status = unsafe { - $pbr$::upb_Encode(self.raw_msg(), mini_table, options, arena.raw(), - &mut buf, &mut len) + // SAFETY: $minitable$ is the one associated with raw_msg(). + let encoded = unsafe { + $pbr$::wire::encode(self.raw_msg(), mini_table) }; //~ TODO: Currently serialize() on the Rust API is an //~ infallible fn, so if upb signals an error here we can only panic. - assert!(status == $pbr$::EncodeStatus::Ok); - let data = if len == 0 { - std::ptr::NonNull::dangling() - } else { - std::ptr::NonNull::new(buf).unwrap() - }; - - // SAFETY: - // - `arena` allocated `data`. - // - `data` is valid for reads up to `len` and will not be mutated. - unsafe { - $pbr$::SerializedData::from_raw_parts(arena, data, len) - } + let serialized = encoded.expect("serialize is not allowed to fail"); + serialized )rs"); return; } @@ -131,27 +113,25 @@ void MessageClearAndParse(Context& ctx, const Descriptor& msg) { let mut msg = Self::new(); // SAFETY: $minitable$ is a static of a const object. let mini_table = unsafe { $std$::ptr::addr_of!($minitable$) }; - let ext_reg = std::ptr::null(); - let options = 0; // SAFETY: // - `data.as_ptr()` is valid to read for `data.len()` // - `mini_table` is the one used to construct `msg.raw_msg()` // - `msg.arena().raw()` is held for the same lifetime as `msg`. let status = unsafe { - $pbr$::upb_Decode( - data.as_ptr(), data.len(), msg.raw_msg(), - mini_table, ext_reg, options, msg.arena().raw()) + $pbr$::wire::decode( + data, msg.raw_msg(), + mini_table, msg.arena()) }; match status { - $pbr$::DecodeStatus::Ok => { + Ok(_) => { //~ This swap causes the old self.inner.arena to be moved into `msg` //~ which we immediately drop, which will release any previous //~ message that was held here. std::mem::swap(self, &mut msg); Ok(()) } - _ => Err($pb$::ParseError) + Err(_) => Err($pb$::ParseError) } )rs"); return;