Move casting functions to MessageLite and use ClassData as the uniqueness

instead of Reflection. This allows using these functions instead of
`dynamic_cast` for all generated types including LITE.

PiperOrigin-RevId: 631927387
pull/16803/head
Protobuf Team Bot 2024-05-08 14:33:40 -07:00 committed by Copybara-Service
parent 7febb4c48f
commit f21cf23fc4
4 changed files with 218 additions and 97 deletions

View File

@ -27,6 +27,7 @@
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/map_lite_test_util.h"
#include "google/protobuf/map_lite_unittest.pb.h"
#include "google/protobuf/message_lite.h"
#include "google/protobuf/parse_context.h"
#include "google/protobuf/test_util_lite.h"
#include "google/protobuf/unittest_lite.pb.h"
@ -1325,6 +1326,92 @@ TEST(LiteBasicTest, CodedInputStreamRollback) {
}
}
// Two arbitary types
using CastType1 = protobuf_unittest::TestAllTypesLite;
using CastType2 = protobuf_unittest::TestPackedTypesLite;
TEST(LiteTest, DynamicCastToGenerated) {
CastType1 test_type_1;
MessageLite* test_type_1_pointer = &test_type_1;
EXPECT_EQ(&test_type_1,
DynamicCastToGenerated<CastType1>(test_type_1_pointer));
EXPECT_EQ(nullptr, DynamicCastToGenerated<CastType2>(test_type_1_pointer));
const MessageLite* test_type_1_pointer_const = &test_type_1;
EXPECT_EQ(&test_type_1,
DynamicCastToGenerated<const CastType1>(test_type_1_pointer_const));
EXPECT_EQ(nullptr,
DynamicCastToGenerated<const CastType2>(test_type_1_pointer_const));
MessageLite* test_type_1_pointer_nullptr = nullptr;
EXPECT_EQ(nullptr,
DynamicCastToGenerated<CastType1>(test_type_1_pointer_nullptr));
MessageLite& test_type_1_pointer_ref = test_type_1;
EXPECT_EQ(&test_type_1,
&DynamicCastToGenerated<CastType1>(test_type_1_pointer_ref));
const MessageLite& test_type_1_pointer_const_ref = test_type_1;
EXPECT_EQ(&test_type_1,
&DynamicCastToGenerated<CastType1>(test_type_1_pointer_const_ref));
}
#if GTEST_HAS_DEATH_TEST
TEST(LiteTest, DynamicCastToGeneratedInvalidReferenceType) {
CastType1 test_type_1;
const MessageLite& test_type_1_pointer_const_ref = test_type_1;
ASSERT_DEATH(DynamicCastToGenerated<CastType2>(test_type_1_pointer_const_ref),
"Cannot downcast " + test_type_1.GetTypeName() + " to " +
CastType2::default_instance().GetTypeName());
}
#endif // GTEST_HAS_DEATH_TEST
TEST(LiteTest, DownCastToGeneratedValidType) {
CastType1 test_type_1;
MessageLite* test_type_1_pointer = &test_type_1;
EXPECT_EQ(&test_type_1, DownCastToGenerated<CastType1>(test_type_1_pointer));
const MessageLite* test_type_1_pointer_const = &test_type_1;
EXPECT_EQ(&test_type_1,
DownCastToGenerated<const CastType1>(test_type_1_pointer_const));
MessageLite* test_type_1_pointer_nullptr = nullptr;
EXPECT_EQ(nullptr,
DownCastToGenerated<CastType1>(test_type_1_pointer_nullptr));
MessageLite& test_type_1_pointer_ref = test_type_1;
EXPECT_EQ(&test_type_1,
&DownCastToGenerated<CastType1>(test_type_1_pointer_ref));
const MessageLite& test_type_1_pointer_const_ref = test_type_1;
EXPECT_EQ(&test_type_1,
&DownCastToGenerated<CastType1>(test_type_1_pointer_const_ref));
}
#if GTEST_HAS_DEATH_TEST
TEST(LiteTest, DownCastToGeneratedInvalidPointerType) {
CastType1 test_type_1;
MessageLite* test_type_1_pointer = &test_type_1;
ASSERT_DEBUG_DEATH(DownCastToGenerated<CastType2>(test_type_1_pointer),
"Cannot downcast " + test_type_1.GetTypeName() + " to " +
CastType2::default_instance().GetTypeName());
}
TEST(LiteTest, DownCastToGeneratedInvalidReferenceType) {
CastType1 test_type_1;
MessageLite& test_type_1_pointer = test_type_1;
ASSERT_DEBUG_DEATH(DownCastToGenerated<CastType2>(test_type_1_pointer),
"Cannot downcast " + test_type_1.GetTypeName() + " to " +
CastType2::default_instance().GetTypeName());
}
#endif // GTEST_HAS_DEATH_TEST
} // namespace
} // namespace protobuf
} // namespace google

View File

@ -1424,90 +1424,6 @@ DECLARE_GET_REPEATED_FIELD(bool)
#undef DECLARE_GET_REPEATED_FIELD
// Tries to downcast this message to a generated message type. Returns nullptr
// if this class is not an instance of T. This works even if RTTI is disabled.
//
// This also has the effect of creating a strong reference to T that will
// prevent the linker from stripping it out at link time. This can be important
// if you are using a DynamicMessageFactory that delegates to the generated
// factory.
template <typename T>
const T* DynamicCastToGenerated(const Message* from) {
// Compile-time assert that T is a generated type that has a
// default_instance() accessor, but avoid actually calling it.
const T& (*get_default_instance)() = &T::default_instance;
(void)get_default_instance;
// Compile-time assert that T is a subclass of google::protobuf::Message.
const Message* unused = static_cast<T*>(nullptr);
(void)unused;
#if PROTOBUF_RTTI
internal::StrongReferenceToType<T>();
return dynamic_cast<const T*>(from);
#else
bool ok = from != nullptr &&
T::default_instance().GetReflection() == from->GetReflection();
return ok ? internal::DownCast<const T*>(from) : nullptr;
#endif
}
template <typename T>
T* DynamicCastToGenerated(Message* from) {
const Message* message_const = from;
return const_cast<T*>(DynamicCastToGenerated<T>(message_const));
}
// An overloaded version of DynamicCastToGenerated for downcasting references to
// base Message class. If the destination type T if the argument is not an
// instance of T and dynamic_cast returns nullptr, it terminates with an error.
template <typename T>
const T& DynamicCastToGenerated(const Message& from) {
const T* destination_message = DynamicCastToGenerated<T>(&from);
ABSL_CHECK(destination_message != nullptr)
<< "Cannot downcast " << from.GetTypeName() << " to "
<< T::default_instance().GetTypeName();
return *destination_message;
}
template <typename T>
T& DynamicCastToGenerated(Message& from) {
const Message& message_const = from;
const T& destination_message = DynamicCastToGenerated<T>(message_const);
return const_cast<T&>(destination_message);
}
// A lightweight function for downcasting base Message pointer to derived type.
// It should only be used when the caller is certain that the argument is of
// instance T and T is a type derived from base Message class.
template <typename T>
const T* DownCastToGenerated(const Message* from) {
internal::StrongReferenceToType<T>();
ABSL_DCHECK(DynamicCastToGenerated<T>(from) == from)
<< "Cannot downcast " << from->GetTypeName() << " to "
<< T::default_instance().GetTypeName();
return static_cast<const T*>(from);
}
template <typename T>
T* DownCastToGenerated(Message* from) {
const Message* message_const = from;
return const_cast<T*>(DownCastToGenerated<T>(message_const));
}
template <typename T>
const T& DownCastToGenerated(const Message& from) {
return *DownCastToGenerated<T>(&from);
}
template <typename T>
T& DownCastToGenerated(Message& from) {
const Message& message_const = from;
const T& destination_message = DownCastToGenerated<T>(message_const);
return const_cast<T&>(destination_message);
}
// Call this function to ensure that this message's reflection is linked into
// the binary:
//

View File

@ -21,6 +21,7 @@
#include <cstdint>
#include <iosfwd>
#include <string>
#include <type_traits>
#include "absl/base/attributes.h"
#include "absl/log/absl_check.h"
@ -53,6 +54,7 @@ class FastReflectionStringSetter;
class Reflection;
class Descriptor;
class AssignDescriptorsHelper;
class MessageLite;
namespace io {
@ -120,6 +122,9 @@ class PROTOBUF_EXPORT CachedSize {
#endif
};
// For MessageLite to friend.
class TypeId;
class SwapFieldHelper;
// See parse_context.h for explanation
@ -638,6 +643,15 @@ class PROTOBUF_EXPORT MessageLite {
// return a default table instead of a unique one.
virtual const ClassData* GetClassData() const = 0;
template <typename T>
static auto GetClassDataGenerated() {
// We could speed this up if needed by avoiding the function call.
// In LTO this is likely inlined, so it might not matter.
static_assert(
std::is_same<const T&, decltype(T::default_instance())>::value, "");
return T::default_instance().T::GetClassData();
}
internal::InternalMetadata _internal_metadata_;
// Return the cached size object as described by
@ -682,6 +696,7 @@ class PROTOBUF_EXPORT MessageLite {
friend class internal::LazyField;
friend class internal::SwapFieldHelper;
friend class internal::TcParser;
friend class internal::TypeId;
friend class internal::WeakFieldMap;
friend class internal::WireFormatLite;
@ -717,6 +732,32 @@ class PROTOBUF_EXPORT MessageLite {
namespace internal {
// A typeinfo equivalent for protobuf message types. Used for
// DynamicCastToGenerated.
// We might make this class public later on to have an alternative to
// `std::type_info` that works when RTTI is disabled.
class TypeId {
public:
constexpr explicit TypeId(const MessageLite::ClassData* data) : data_(data) {}
friend constexpr bool operator==(TypeId a, TypeId b) {
return a.data_ == b.data_;
}
friend constexpr bool operator!=(TypeId a, TypeId b) { return !(a == b); }
static TypeId Get(const MessageLite& msg) {
return TypeId(msg.GetClassData());
}
template <typename T>
static TypeId Get() {
return TypeId(MessageLite::GetClassDataGenerated<T>());
}
private:
const MessageLite::ClassData* data_;
};
template <bool alias>
bool MergeFromImpl(absl::string_view input, MessageLite* msg,
const internal::TcParseTableBase* tc_table,
@ -820,6 +861,83 @@ T* OnShutdownDelete(T* p) {
std::string ShortFormat(const MessageLite& message_lite);
std::string Utf8Format(const MessageLite& message_lite);
// Tries to downcast this message to a generated message type. Returns nullptr
// if this class is not an instance of T. This works even if RTTI is disabled.
//
// This also has the effect of creating a strong reference to T that will
// prevent the linker from stripping it out at link time. This can be important
// if you are using a DynamicMessageFactory that delegates to the generated
// factory.
template <typename T>
const T* DynamicCastToGenerated(const MessageLite* from) {
static_assert(std::is_base_of<MessageLite, T>::value, "");
internal::StrongReferenceToType<T>();
// We might avoid the call to T::GetClassData() altogether if T were to
// expose the class data pointer.
if (from == nullptr ||
internal::TypeId::Get<T>() != internal::TypeId::Get(*from)) {
return nullptr;
}
return static_cast<const T*>(from);
}
template <typename T>
const T* DynamicCastToGenerated(const MessageLite* from);
template <typename T>
T* DynamicCastToGenerated(MessageLite* from) {
return const_cast<T*>(
DynamicCastToGenerated<T>(static_cast<const MessageLite*>(from)));
}
// An overloaded version of DynamicCastToGenerated for downcasting references to
// base Message class. If the argument is not an instance of T, it terminates
// with an error.
template <typename T>
const T& DynamicCastToGenerated(const MessageLite& from) {
const T* destination_message = DynamicCastToGenerated<T>(&from);
ABSL_CHECK(destination_message != nullptr)
<< "Cannot downcast " << from.GetTypeName() << " to "
<< T::default_instance().GetTypeName();
return *destination_message;
}
template <typename T>
T& DynamicCastToGenerated(MessageLite& from) {
return const_cast<T&>(
DynamicCastToGenerated<T>(static_cast<const MessageLite&>(from)));
}
// A lightweight function for downcasting base MessageLite pointer to derived
// type. It should only be used when the caller is certain that the argument is
// of instance T and T is a generated message type.
template <typename T>
const T* DownCastToGenerated(const MessageLite* from) {
internal::StrongReferenceToType<T>();
ABSL_DCHECK(DynamicCastToGenerated<T>(from) == from)
<< "Cannot downcast " << from->GetTypeName() << " to "
<< T::default_instance().GetTypeName();
return static_cast<const T*>(from);
}
template <typename T>
T* DownCastToGenerated(MessageLite* from) {
return const_cast<T*>(
DownCastToGenerated<T>(static_cast<const MessageLite*>(from)));
}
template <typename T>
const T& DownCastToGenerated(const MessageLite& from) {
return *DownCastToGenerated<T>(&from);
}
template <typename T>
T& DownCastToGenerated(MessageLite& from) {
return *DownCastToGenerated<T>(&from);
}
} // namespace protobuf
} // namespace google

View File

@ -753,35 +753,35 @@ TEST(MESSAGE_TEST_NAME, InitializationErrorString) {
TEST(MESSAGE_TEST_NAME, DynamicCastToGenerated) {
UNITTEST::TestAllTypes test_all_types;
Message* test_all_types_pointer = &test_all_types;
MessageLite* test_all_types_pointer = &test_all_types;
EXPECT_EQ(&test_all_types, DynamicCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer));
EXPECT_EQ(nullptr, DynamicCastToGenerated<UNITTEST::TestRequired>(
test_all_types_pointer));
const Message* test_all_types_pointer_const = &test_all_types;
const MessageLite* test_all_types_pointer_const = &test_all_types;
EXPECT_EQ(&test_all_types,
DynamicCastToGenerated<const UNITTEST::TestAllTypes>(
test_all_types_pointer_const));
EXPECT_EQ(nullptr, DynamicCastToGenerated<const UNITTEST::TestRequired>(
test_all_types_pointer_const));
Message* test_all_types_pointer_nullptr = nullptr;
MessageLite* test_all_types_pointer_nullptr = nullptr;
EXPECT_EQ(nullptr, DynamicCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer_nullptr));
Message& test_all_types_pointer_ref = test_all_types;
MessageLite& test_all_types_pointer_ref = test_all_types;
EXPECT_EQ(&test_all_types, &DynamicCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer_ref));
const Message& test_all_types_pointer_const_ref = test_all_types;
const MessageLite& test_all_types_pointer_const_ref = test_all_types;
EXPECT_EQ(&test_all_types, &DynamicCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer_const_ref));
}
TEST(MESSAGE_TEST_NAME, DynamicCastToGeneratedInvalidReferenceType) {
UNITTEST::TestAllTypes test_all_types;
const Message& test_all_types_pointer_const_ref = test_all_types;
const MessageLite& test_all_types_pointer_const_ref = test_all_types;
ASSERT_DEATH(DynamicCastToGenerated<UNITTEST::TestRequired>(
test_all_types_pointer_const_ref),
"Cannot downcast " + test_all_types.GetTypeName() + " to " +
@ -791,23 +791,23 @@ TEST(MESSAGE_TEST_NAME, DynamicCastToGeneratedInvalidReferenceType) {
TEST(MESSAGE_TEST_NAME, DownCastToGeneratedValidType) {
UNITTEST::TestAllTypes test_all_types;
Message* test_all_types_pointer = &test_all_types;
MessageLite* test_all_types_pointer = &test_all_types;
EXPECT_EQ(&test_all_types, DownCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer));
const Message* test_all_types_pointer_const = &test_all_types;
const MessageLite* test_all_types_pointer_const = &test_all_types;
EXPECT_EQ(&test_all_types, DownCastToGenerated<const UNITTEST::TestAllTypes>(
test_all_types_pointer_const));
Message* test_all_types_pointer_nullptr = nullptr;
MessageLite* test_all_types_pointer_nullptr = nullptr;
EXPECT_EQ(nullptr, DownCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer_nullptr));
Message& test_all_types_pointer_ref = test_all_types;
MessageLite& test_all_types_pointer_ref = test_all_types;
EXPECT_EQ(&test_all_types, &DownCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer_ref));
const Message& test_all_types_pointer_const_ref = test_all_types;
const MessageLite& test_all_types_pointer_const_ref = test_all_types;
EXPECT_EQ(&test_all_types, &DownCastToGenerated<UNITTEST::TestAllTypes>(
test_all_types_pointer_const_ref));
}
@ -815,7 +815,7 @@ TEST(MESSAGE_TEST_NAME, DownCastToGeneratedValidType) {
TEST(MESSAGE_TEST_NAME, DownCastToGeneratedInvalidPointerType) {
UNITTEST::TestAllTypes test_all_types;
Message* test_all_types_pointer = &test_all_types;
MessageLite* test_all_types_pointer = &test_all_types;
ASSERT_DEBUG_DEATH(
DownCastToGenerated<UNITTEST::TestRequired>(test_all_types_pointer),
@ -826,7 +826,7 @@ TEST(MESSAGE_TEST_NAME, DownCastToGeneratedInvalidPointerType) {
TEST(MESSAGE_TEST_NAME, DownCastToGeneratedInvalidReferenceType) {
UNITTEST::TestAllTypes test_all_types;
Message& test_all_types_pointer = test_all_types;
MessageLite& test_all_types_pointer = test_all_types;
ASSERT_DEBUG_DEATH(
DownCastToGenerated<UNITTEST::TestRequired>(test_all_types_pointer),