Nextgen Proto Pythonic API: Add 'in' operator

(Second attempt. The first attempt missed ListValue)

The “in” operator will be consistent with HasField, but slightly different from Proto Plus.

The detailed behavior of the “in” operator in Nextgen:

* For WKT Struct (to be consistent with the old Struct behavior):
    -Raise TypeError if the key passed is not a string
    -Check if the key is in struct.fields

* For WKT ListValue (to be consistent with the old behavior):
    -Check if the value is in list_value.values

* For other messages:
    -Raise ValueError if the argument passed is not a string
    -Raise ValueError if the string is not the name of a field
    -For a oneof: check whether any field under the oneof is set
    -For a has-presence field: check whether it is set
    -For a non-has-presence field (including repeated fields): raise ValueError

PiperOrigin-RevId: 631143378
pull/16754/head
Jie Luo 2024-05-06 12:10:59 -07:00 committed by Copybara-Service
parent e949bba22a
commit 24f27c3b88
7 changed files with 194 additions and 5 deletions

View File

@ -1336,6 +1336,24 @@ class MessageTest(unittest.TestCase):
union.DESCRIPTOR, message_module.TestAllTypes.NestedEnum.DESCRIPTOR
)
def testIn(self, message_module):
  """Exercises the `in` operator on TestAllTypes before and after filling it."""
  msg = message_module.TestAllTypes()

  # On a freshly constructed message none of these names is reported as set.
  for unset_name in ('optional_nested_message', 'oneof_bytes', 'oneof_string'):
    self.assertNotIn(unset_name, msg)

  # Repeated fields, a non-string argument and an unknown field name are all
  # expected to be rejected with ValueError.
  for rejected in (
      'repeated_int32',
      'repeated_nested_message',
      1,
      'not_a_field',
  ):
    with self.assertRaises(ValueError):
      rejected in msg

  # After SetAllFields the test expects oneof_bytes (but not oneof_string) to
  # be the populated member of the oneof.
  test_util.SetAllFields(msg)
  self.assertIn('optional_nested_message', msg)
  self.assertIn('oneof_bytes', msg)
  self.assertNotIn('oneof_string', msg)
# Class to test proto2-only features (required, extensions, etc.)
@testing_refleaks.TestCase
@ -1367,6 +1385,9 @@ class Proto2Test(unittest.TestCase):
self.assertTrue(message.HasField('optional_int32'))
self.assertTrue(message.HasField('optional_bool'))
self.assertTrue(message.HasField('optional_nested_message'))
self.assertIn('optional_int32', message)
self.assertIn('optional_bool', message)
self.assertIn('optional_nested_message', message)
# Set the fields to non-default values.
message.optional_int32 = 5
@ -1385,6 +1406,9 @@ class Proto2Test(unittest.TestCase):
self.assertFalse(message.HasField('optional_int32'))
self.assertFalse(message.HasField('optional_bool'))
self.assertFalse(message.HasField('optional_nested_message'))
self.assertNotIn('optional_int32', message)
self.assertNotIn('optional_bool', message)
self.assertNotIn('optional_nested_message', message)
self.assertEqual(0, message.optional_int32)
self.assertEqual(False, message.optional_bool)
self.assertEqual(0, message.optional_nested_message.bb)
@ -1711,6 +1735,12 @@ class Proto3Test(unittest.TestCase):
with self.assertRaises(ValueError):
message.HasField('repeated_nested_message')
# Can not test "in" operator.
with self.assertRaises(ValueError):
'repeated_int32' in message
with self.assertRaises(ValueError):
'repeated_nested_message' in message
# Fields should default to their type-specific default.
self.assertEqual(0, message.optional_int32)
self.assertEqual(0, message.optional_float)
@ -1721,6 +1751,7 @@ class Proto3Test(unittest.TestCase):
# Setting a submessage should still return proper presence information.
message.optional_nested_message.bb = 0
self.assertTrue(message.HasField('optional_nested_message'))
self.assertIn('optional_nested_message', message)
# Set the fields to non-default values.
message.optional_int32 = 5

View File

@ -1000,6 +1000,21 @@ def _AddUnicodeMethod(unused_message_descriptor, cls):
cls.__unicode__ = __unicode__
def _AddContainsMethod(message_descriptor, cls):
if message_descriptor.full_name == 'google.protobuf.Struct':
def __contains__(self, key):
return key in self.fields
elif message_descriptor.full_name == 'google.protobuf.ListValue':
def __contains__(self, value):
return value in self.items()
else:
def __contains__(self, field):
return self.HasField(field)
cls.__contains__ = __contains__
def _BytesForNonRepeatedElement(value, field_number, field_type):
"""Returns the number of bytes needed to serialize a non-repeated element.
The returned byte count includes space for tag information and any
@ -1394,6 +1409,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddStrMethod(message_descriptor, cls)
_AddReprMethod(message_descriptor, cls)
_AddUnicodeMethod(message_descriptor, cls)
_AddContainsMethod(message_descriptor, cls)
_AddByteSizeMethod(message_descriptor, cls)
_AddSerializeToStringMethod(message_descriptor, cls)
_AddSerializePartialToStringMethod(message_descriptor, cls)

View File

@ -497,9 +497,6 @@ class Struct(object):
def __getitem__(self, key):
return _GetStructValue(self.fields[key])
def __contains__(self, item):
return item in self.fields
def __setitem__(self, key, value):
_SetStructValue(self.fields[key], value)

View File

@ -515,7 +515,6 @@ class StructTest(unittest.TestCase):
self.assertEqual(False, struct_list[3])
self.assertEqual(None, struct_list[4])
self.assertEqual(inner_struct, struct_list[5])
self.assertIn(6, struct_list)
struct_list[1] = 7
self.assertEqual(7, struct_list[1])
@ -570,6 +569,36 @@ class StructTest(unittest.TestCase):
self.assertEqual([6, True, False, None, inner_struct],
list(struct['key5'].items()))
def testInOperator(self):
  """Checks `in` on Struct (by key) and on ListValue (by contained value)."""
  # Struct: membership is decided by the keys that were stored; the string
  # 'fields' was never stored as a key, so it must not be reported.
  struct = struct_pb2.Struct()
  struct['key'] = 5
  self.assertIn('key', struct)
  self.assertNotIn('fields', struct)
  with self.assertRaises(TypeError):
    1 in struct

  # ListValue: membership is decided by the stored values.
  struct_list = struct.get_or_create_list('key2')
  self.assertIsInstance(struct_list, collections_abc.Sequence)
  struct_list.extend([6, 'seven', True, False, None])
  struct_list.add_struct()['subkey'] = 9
  inner_struct = struct.__class__()
  inner_struct['subkey'] = 9

  for present in (6, 'seven', True, False, None, inner_struct):
    self.assertIn(present, struct_list)
  for absent in ('values', 10):
    self.assertNotIn(absent, struct_list)

  # Every element obtained by iteration must also be found by `in`.
  for item in struct_list:
    self.assertIn(item, struct_list)
def testStructAssignment(self):
# Tests struct assignment from another struct
s1 = struct_pb2.Struct()

View File

@ -75,6 +75,34 @@ class Message(object):
"""Outputs a human-readable representation of the message."""
raise NotImplementedError
def __contains__(self, field_name_or_key):
  """Checks if a certain field is set for the message.

  Fields with presence return true if the field is set and false if it is
  not set. Fields without presence raise `ValueError` (this includes
  repeated fields, map fields, and implicit presence fields).

  If field_name is not defined in the message descriptor, `ValueError` will
  be raised.

  Note: WKT Struct checks if the key is contained in fields. ListValue checks
  if the item is contained in the list.

  Args:
    field_name_or_key: For Struct, the key (str) of the fields map. For
      ListValue, any type that may be contained in the list. For other
      messages, name of the field (str) to check for presence.

  Returns:
    bool: For Struct, whether the item is contained in fields. For ListValue,
          whether the item is contained in the list. For other messages,
          whether a value has been set for the named field.

  Raises:
    ValueError: For normal messages, if the `field_name_or_key` is not a
                member of this message or `field_name_or_key` is not a string.
    TypeError: For Struct, if `field_name_or_key` is not a str.
  """
  # Abstract in this base class: this method always raises.
  raise NotImplementedError
def MergeFrom(self, other_msg):
"""Merges the contents of the specified message into current message.

View File

@ -10,6 +10,7 @@
#include "google/protobuf/pyext/message.h"
#include <Python.h>
#include <structmember.h> // A Python header file.
#include <cstdint>
@ -36,6 +37,7 @@
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/strtod.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/unknown_field_set.h"
@ -85,6 +87,12 @@ class MessageReflectionFriend {
return reflection->IsLazyField(field) ||
reflection->IsLazyExtension(message, field);
}
// Thin forwarding helper: calls reflection->ContainsMapKey(message, field,
// map_key) and returns its result unchanged.
// NOTE(review): presumably true when `map_key` is present in map field
// `field` of `message` -- confirm against Reflection::ContainsMapKey.
static bool ContainsMapKey(const Reflection* reflection,
                           const Message& message,
                           const FieldDescriptor* field,
                           const MapKey& map_key) {
  return reflection->ContainsMapKey(message, field, map_key);
}
};
static PyObject* kDESCRIPTOR;
@ -1293,11 +1301,16 @@ PyObject* HasField(CMessage* self, PyObject* arg) {
char* field_name;
Py_ssize_t size;
field_name = const_cast<char*>(PyUnicode_AsUTF8AndSize(arg, &size));
Message* message = self->message;
if (!field_name) {
PyErr_Format(PyExc_ValueError,
"The field name passed to message %s"
" is not a str.",
message->GetDescriptor()->name().c_str());
return nullptr;
}
Message* message = self->message;
bool is_in_oneof;
const FieldDescriptor* field_descriptor = FindFieldWithOneofs(
message, absl::string_view(field_name, size), &is_in_oneof);
@ -2290,6 +2303,48 @@ PyObject* ToUnicode(CMessage* self) {
return decoded;
}
// Implements the Python `in` operator (`__contains__`) for a message.
//
//  * google.protobuf.Struct: `arg` must be a str key; the result tells whether
//    that key is present in the `fields` map. Any other argument type raises
//    TypeError.
//  * google.protobuf.ListValue: the result tells whether `arg` is contained in
//    the sequence returned by the message's Python-level `items()` method.
//  * Every other message: delegates to HasField(self, arg).
//
// Returns a new reference to a Python bool, or nullptr with a Python exception
// set.
PyObject* Contains(CMessage* self, PyObject* arg) {
  Message* message = self->message;
  const Descriptor* descriptor = message->GetDescriptor();
  switch (descriptor->well_known_type()) {
    case Descriptor::WELLKNOWNTYPE_STRUCT: {
      // For WKT Struct, check if the key is in the fields.
      const Reflection* reflection = message->GetReflection();
      const FieldDescriptor* map_field = descriptor->FindFieldByName("fields");
      const FieldDescriptor* key_field = map_field->message_type()->map_key();
      PyObject* py_string = CheckString(arg, key_field);
      if (!py_string) {
        PyErr_SetString(PyExc_TypeError,
                        "The key passed to Struct message must be a str.");
        return nullptr;
      }
      char* value;
      Py_ssize_t value_len;
      if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
        // PyBytes_AsStringAndSize has already set an exception. Returning a
        // non-null object here would hand the interpreter a result with an
        // error still pending, so propagate the failure instead.
        Py_DECREF(py_string);
        return nullptr;
      }
      // Copy the key out before dropping the reference that owns the buffer.
      std::string key_str(value, value_len);
      Py_DECREF(py_string);
      MapKey map_key;
      map_key.SetStringValue(key_str);
      return PyBool_FromLong(MessageReflectionFriend::ContainsMapKey(
          reflection, *message, map_field, map_key));
    }
    case Descriptor::WELLKNOWNTYPE_LISTVALUE: {
      // For WKT ListValue, check if the key is in the items.
      PyObject* items = PyObject_CallMethod(reinterpret_cast<PyObject*>(self),
                                            "items", nullptr);
      if (items == nullptr) {
        return nullptr;  // The call failed; its exception is already set.
      }
      const int found = PySequence_Contains(items, arg);
      Py_DECREF(items);  // `items` is a new reference owned by this function.
      if (found < 0) {
        // -1 signals an error (e.g. a comparison raised); it must not be
        // converted into True.
        return nullptr;
      }
      return PyBool_FromLong(found);
    }
    default:
      // For other messages, check with HasField.
      return HasField(self, arg);
  }
}
// CMessage static methods:
PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
PyObject* unused_arg) {
@ -2338,6 +2393,8 @@ static PyMethodDef Methods[] = {
"Makes a deep copy of the class."},
{"__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
"Outputs a unicode representation of the message."},
{"__contains__", (PyCFunction)Contains, METH_O,
"Checks if a message field is set."},
{"ByteSize", (PyCFunction)ByteSize, METH_NOARGS,
"Returns the size of the message in bytes."},
{"Clear", (PyCFunction)Clear, METH_NOARGS, "Clears the message."},

View File

@ -1044,6 +1044,35 @@ static PyObject* PyUpb_Message_HasField(PyObject* _self, PyObject* arg) {
NULL);
}
// Implements the Python `in` operator (`__contains__`) for a upb message.
//
//  * Struct: converts `arg` with the key field of the `fields` map entry and
//    reports whether that key is in the map. A stub message reports false.
//  * ListValue: reports whether `arg` is contained in the sequence returned by
//    the message's Python-level `items()` method. A stub message reports false.
//  * Every other message: delegates to PyUpb_Message_HasField().
//
// Returns a new reference to a Python bool, or NULL with a Python exception
// set.
static PyObject* PyUpb_Message_Contains(PyObject* _self, PyObject* arg) {
  const upb_MessageDef* msgdef = PyUpb_Message_GetMsgdef(_self);
  switch (upb_MessageDef_WellKnownType(msgdef)) {
    case kUpb_WellKnown_Struct: {
      // For WKT Struct, check if the key is in the fields.
      PyUpb_Message* self = (void*)_self;
      if (PyUpb_Message_IsStub(self)) Py_RETURN_FALSE;
      upb_Message* msg = PyUpb_Message_GetMsg(self);
      const upb_FieldDef* f = upb_MessageDef_FindFieldByName(msgdef, "fields");
      const upb_Map* map = upb_Message_GetFieldByDef(msg, f).map_val;
      const upb_MessageDef* entry_m = upb_FieldDef_MessageSubDef(f);
      const upb_FieldDef* key_f = upb_MessageDef_Field(entry_m, 0);
      upb_MessageValue u_key;
      // Convert the key first so that an invalid key is still reported as an
      // error even when the map turns out to be empty.
      if (!PyUpb_PyToUpb(arg, key_f, &u_key, NULL)) return NULL;
      // NOTE(review): an unset map field may be handed back as NULL; guard so
      // upb_Map_Get() is never called on a null map -- confirm against upb.
      if (map == NULL) Py_RETURN_FALSE;
      return PyBool_FromLong(upb_Map_Get(map, u_key, NULL));
    }
    case kUpb_WellKnown_ListValue: {
      // For WKT ListValue, check if the key is in the items.
      PyUpb_Message* self = (void*)_self;
      if (PyUpb_Message_IsStub(self)) Py_RETURN_FALSE;
      PyObject* items = PyObject_CallMethod(_self, "items", NULL);
      if (items == NULL) return NULL;  // Exception already set by the call.
      int found = PySequence_Contains(items, arg);
      Py_DECREF(items);  // `items` is a new reference owned by this function.
      // -1 signals an error and must not be converted into True.
      if (found < 0) return NULL;
      return PyBool_FromLong(found);
    }
    default:
      // For other messages, check with HasField.
      return PyUpb_Message_HasField(_self, arg);
  }
}
static PyObject* PyUpb_Message_FindInitializationErrors(PyObject* _self,
PyObject* arg);
@ -1642,6 +1671,8 @@ static PyMethodDef PyUpb_Message_Methods[] = {
// TODO
//{ "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
// "Outputs a unicode representation of the message." },
{"__contains__", PyUpb_Message_Contains, METH_O,
"Checks if a message field is set."},
{"ByteSize", (PyCFunction)PyUpb_Message_ByteSize, METH_NOARGS,
"Returns the size of the message in bytes."},
{"Clear", (PyCFunction)PyUpb_Message_Clear, METH_NOARGS,