Nextgen Proto Pythonic API: Add 'in' operator
(Second attempt. The first attempt missed ListValue) The “in” operator will be consistent with HasField but a little different with Proto Plus. The detail behavior of “in” operator in Nextgen * For WKT Struct (to be consist with old Struct behavior): -Raise TypeError if not pass a string -Check if the key is in the struct.fields * For WKT ListValue (to be consist with old behavior): -Check if the key is in the list_value.values * For other messages: -Raise ValueError if not pass a string -Raise ValueError if the string is not a field -For Oneof: Check any field under the oneof is set -For has-presence field: check if set -For non-has-presence field (include repeated fields): raise ValueError PiperOrigin-RevId: 631143378pull/16754/head
parent
e949bba22a
commit
24f27c3b88
|
@ -1336,6 +1336,24 @@ class MessageTest(unittest.TestCase):
|
|||
union.DESCRIPTOR, message_module.TestAllTypes.NestedEnum.DESCRIPTOR
|
||||
)
|
||||
|
||||
def testIn(self, message_module):
|
||||
m = message_module.TestAllTypes()
|
||||
self.assertNotIn('optional_nested_message', m)
|
||||
self.assertNotIn('oneof_bytes', m)
|
||||
self.assertNotIn('oneof_string', m)
|
||||
with self.assertRaises(ValueError) as e:
|
||||
'repeated_int32' in m
|
||||
with self.assertRaises(ValueError) as e:
|
||||
'repeated_nested_message' in m
|
||||
with self.assertRaises(ValueError) as e:
|
||||
1 in m
|
||||
with self.assertRaises(ValueError) as e:
|
||||
'not_a_field' in m
|
||||
test_util.SetAllFields(m)
|
||||
self.assertIn('optional_nested_message', m)
|
||||
self.assertIn('oneof_bytes', m)
|
||||
self.assertNotIn('oneof_string', m)
|
||||
|
||||
|
||||
# Class to test proto2-only features (required, extensions, etc.)
|
||||
@testing_refleaks.TestCase
|
||||
|
@ -1367,6 +1385,9 @@ class Proto2Test(unittest.TestCase):
|
|||
self.assertTrue(message.HasField('optional_int32'))
|
||||
self.assertTrue(message.HasField('optional_bool'))
|
||||
self.assertTrue(message.HasField('optional_nested_message'))
|
||||
self.assertIn('optional_int32', message)
|
||||
self.assertIn('optional_bool', message)
|
||||
self.assertIn('optional_nested_message', message)
|
||||
|
||||
# Set the fields to non-default values.
|
||||
message.optional_int32 = 5
|
||||
|
@ -1385,6 +1406,9 @@ class Proto2Test(unittest.TestCase):
|
|||
self.assertFalse(message.HasField('optional_int32'))
|
||||
self.assertFalse(message.HasField('optional_bool'))
|
||||
self.assertFalse(message.HasField('optional_nested_message'))
|
||||
self.assertNotIn('optional_int32', message)
|
||||
self.assertNotIn('optional_bool', message)
|
||||
self.assertNotIn('optional_nested_message', message)
|
||||
self.assertEqual(0, message.optional_int32)
|
||||
self.assertEqual(False, message.optional_bool)
|
||||
self.assertEqual(0, message.optional_nested_message.bb)
|
||||
|
@ -1711,6 +1735,12 @@ class Proto3Test(unittest.TestCase):
|
|||
with self.assertRaises(ValueError):
|
||||
message.HasField('repeated_nested_message')
|
||||
|
||||
# Can not test "in" operator.
|
||||
with self.assertRaises(ValueError):
|
||||
'repeated_int32' in message
|
||||
with self.assertRaises(ValueError):
|
||||
'repeated_nested_message' in message
|
||||
|
||||
# Fields should default to their type-specific default.
|
||||
self.assertEqual(0, message.optional_int32)
|
||||
self.assertEqual(0, message.optional_float)
|
||||
|
@ -1721,6 +1751,7 @@ class Proto3Test(unittest.TestCase):
|
|||
# Setting a submessage should still return proper presence information.
|
||||
message.optional_nested_message.bb = 0
|
||||
self.assertTrue(message.HasField('optional_nested_message'))
|
||||
self.assertIn('optional_nested_message', message)
|
||||
|
||||
# Set the fields to non-default values.
|
||||
message.optional_int32 = 5
|
||||
|
|
|
@ -1000,6 +1000,21 @@ def _AddUnicodeMethod(unused_message_descriptor, cls):
|
|||
cls.__unicode__ = __unicode__
|
||||
|
||||
|
||||
def _AddContainsMethod(message_descriptor, cls):
|
||||
|
||||
if message_descriptor.full_name == 'google.protobuf.Struct':
|
||||
def __contains__(self, key):
|
||||
return key in self.fields
|
||||
elif message_descriptor.full_name == 'google.protobuf.ListValue':
|
||||
def __contains__(self, value):
|
||||
return value in self.items()
|
||||
else:
|
||||
def __contains__(self, field):
|
||||
return self.HasField(field)
|
||||
|
||||
cls.__contains__ = __contains__
|
||||
|
||||
|
||||
def _BytesForNonRepeatedElement(value, field_number, field_type):
|
||||
"""Returns the number of bytes needed to serialize a non-repeated element.
|
||||
The returned byte count includes space for tag information and any
|
||||
|
@ -1394,6 +1409,7 @@ def _AddMessageMethods(message_descriptor, cls):
|
|||
_AddStrMethod(message_descriptor, cls)
|
||||
_AddReprMethod(message_descriptor, cls)
|
||||
_AddUnicodeMethod(message_descriptor, cls)
|
||||
_AddContainsMethod(message_descriptor, cls)
|
||||
_AddByteSizeMethod(message_descriptor, cls)
|
||||
_AddSerializeToStringMethod(message_descriptor, cls)
|
||||
_AddSerializePartialToStringMethod(message_descriptor, cls)
|
||||
|
|
|
@ -497,9 +497,6 @@ class Struct(object):
|
|||
def __getitem__(self, key):
|
||||
return _GetStructValue(self.fields[key])
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self.fields
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
_SetStructValue(self.fields[key], value)
|
||||
|
||||
|
|
|
@ -515,7 +515,6 @@ class StructTest(unittest.TestCase):
|
|||
self.assertEqual(False, struct_list[3])
|
||||
self.assertEqual(None, struct_list[4])
|
||||
self.assertEqual(inner_struct, struct_list[5])
|
||||
self.assertIn(6, struct_list)
|
||||
|
||||
struct_list[1] = 7
|
||||
self.assertEqual(7, struct_list[1])
|
||||
|
@ -570,6 +569,36 @@ class StructTest(unittest.TestCase):
|
|||
self.assertEqual([6, True, False, None, inner_struct],
|
||||
list(struct['key5'].items()))
|
||||
|
||||
def testInOperator(self):
|
||||
# in operator for Struct
|
||||
struct = struct_pb2.Struct()
|
||||
struct['key'] = 5
|
||||
|
||||
self.assertIn('key', struct)
|
||||
self.assertNotIn('fields', struct)
|
||||
with self.assertRaises(TypeError) as e:
|
||||
1 in struct
|
||||
|
||||
# in operator for ListValue
|
||||
struct_list = struct.get_or_create_list('key2')
|
||||
self.assertIsInstance(struct_list, collections_abc.Sequence)
|
||||
struct_list.extend([6, 'seven', True, False, None])
|
||||
struct_list.add_struct()['subkey'] = 9
|
||||
inner_struct = struct.__class__()
|
||||
inner_struct['subkey'] = 9
|
||||
|
||||
self.assertIn(6, struct_list)
|
||||
self.assertIn('seven', struct_list)
|
||||
self.assertIn(True, struct_list)
|
||||
self.assertIn(False, struct_list)
|
||||
self.assertIn(None, struct_list)
|
||||
self.assertIn(inner_struct, struct_list)
|
||||
self.assertNotIn('values', struct_list)
|
||||
self.assertNotIn(10, struct_list)
|
||||
|
||||
for item in struct_list:
|
||||
self.assertIn(item, struct_list)
|
||||
|
||||
def testStructAssignment(self):
|
||||
# Tests struct assignment from another struct
|
||||
s1 = struct_pb2.Struct()
|
||||
|
|
|
@ -75,6 +75,34 @@ class Message(object):
|
|||
"""Outputs a human-readable representation of the message."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __contains__(self, field_name_or_key):
|
||||
"""Checks if a certain field is set for the message.
|
||||
|
||||
Has presence fields return true if the field is set, false if the field is
|
||||
not set. Fields without presence do raise `ValueError` (this includes
|
||||
repeated fields, map fields, and implicit presence fields).
|
||||
|
||||
If field_name is not defined in the message descriptor, `ValueError` will
|
||||
be raised.
|
||||
Note: WKT Struct checks if the key is contained in fields. ListValue checks
|
||||
if the item is contained in the list.
|
||||
|
||||
Args:
|
||||
field_name_or_key: For Struct, the key (str) of the fields map. For
|
||||
ListValue, any type that may be contained in the list. For other
|
||||
messages, name of the field (str) to check for presence.
|
||||
|
||||
Returns:
|
||||
bool: For Struct, whether the item is contained in fields. For ListValue,
|
||||
whether the item is contained in the list. For other message,
|
||||
whether a value has been set for the named field.
|
||||
|
||||
Raises:
|
||||
ValueError: For normal messages, if the `field_name_or_key` is not a
|
||||
member of this message or `field_name_or_key` is not a string.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def MergeFrom(self, other_msg):
|
||||
"""Merges the contents of the specified message into current message.
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
#include "google/protobuf/pyext/message.h"
|
||||
|
||||
#include <Python.h>
|
||||
#include <structmember.h> // A Python header file.
|
||||
|
||||
#include <cstdint>
|
||||
|
@ -36,6 +37,7 @@
|
|||
#include "google/protobuf/io/coded_stream.h"
|
||||
#include "google/protobuf/io/strtod.h"
|
||||
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
|
||||
#include "google/protobuf/map_field.h"
|
||||
#include "google/protobuf/message.h"
|
||||
#include "google/protobuf/text_format.h"
|
||||
#include "google/protobuf/unknown_field_set.h"
|
||||
|
@ -85,6 +87,12 @@ class MessageReflectionFriend {
|
|||
return reflection->IsLazyField(field) ||
|
||||
reflection->IsLazyExtension(message, field);
|
||||
}
|
||||
static bool ContainsMapKey(const Reflection* reflection,
|
||||
const Message& message,
|
||||
const FieldDescriptor* field,
|
||||
const MapKey& map_key) {
|
||||
return reflection->ContainsMapKey(message, field, map_key);
|
||||
}
|
||||
};
|
||||
|
||||
static PyObject* kDESCRIPTOR;
|
||||
|
@ -1293,11 +1301,16 @@ PyObject* HasField(CMessage* self, PyObject* arg) {
|
|||
char* field_name;
|
||||
Py_ssize_t size;
|
||||
field_name = const_cast<char*>(PyUnicode_AsUTF8AndSize(arg, &size));
|
||||
Message* message = self->message;
|
||||
|
||||
if (!field_name) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"The field name passed to message %s"
|
||||
" is not a str.",
|
||||
message->GetDescriptor()->name().c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Message* message = self->message;
|
||||
bool is_in_oneof;
|
||||
const FieldDescriptor* field_descriptor = FindFieldWithOneofs(
|
||||
message, absl::string_view(field_name, size), &is_in_oneof);
|
||||
|
@ -2290,6 +2303,48 @@ PyObject* ToUnicode(CMessage* self) {
|
|||
return decoded;
|
||||
}
|
||||
|
||||
PyObject* Contains(CMessage* self, PyObject* arg) {
|
||||
Message* message = self->message;
|
||||
const Descriptor* descriptor = message->GetDescriptor();
|
||||
switch (descriptor->well_known_type()) {
|
||||
case Descriptor::WELLKNOWNTYPE_STRUCT: {
|
||||
// For WKT Struct, check if the key is in the fields.
|
||||
const Reflection* reflection = message->GetReflection();
|
||||
const FieldDescriptor* map_field = descriptor->FindFieldByName("fields");
|
||||
const FieldDescriptor* key_field = map_field->message_type()->map_key();
|
||||
PyObject* py_string = CheckString(arg, key_field);
|
||||
if (!py_string) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"The key passed to Struct message must be a str.");
|
||||
return nullptr;
|
||||
}
|
||||
char* value;
|
||||
Py_ssize_t value_len;
|
||||
if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
|
||||
Py_DECREF(py_string);
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
std::string key_str;
|
||||
key_str.assign(value, value_len);
|
||||
Py_DECREF(py_string);
|
||||
|
||||
MapKey map_key;
|
||||
map_key.SetStringValue(key_str);
|
||||
return PyBool_FromLong(MessageReflectionFriend::ContainsMapKey(
|
||||
reflection, *message, map_field, map_key));
|
||||
}
|
||||
case Descriptor::WELLKNOWNTYPE_LISTVALUE: {
|
||||
// For WKT ListValue, check if the key is in the items.
|
||||
PyObject* items = PyObject_CallMethod(reinterpret_cast<PyObject*>(self),
|
||||
"items", nullptr);
|
||||
return PyBool_FromLong(PySequence_Contains(items, arg));
|
||||
}
|
||||
default:
|
||||
// For other messages, check with HasField.
|
||||
return HasField(self, arg);
|
||||
}
|
||||
}
|
||||
|
||||
// CMessage static methods:
|
||||
PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
|
||||
PyObject* unused_arg) {
|
||||
|
@ -2338,6 +2393,8 @@ static PyMethodDef Methods[] = {
|
|||
"Makes a deep copy of the class."},
|
||||
{"__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
|
||||
"Outputs a unicode representation of the message."},
|
||||
{"__contains__", (PyCFunction)Contains, METH_O,
|
||||
"Checks if a message field is set."},
|
||||
{"ByteSize", (PyCFunction)ByteSize, METH_NOARGS,
|
||||
"Returns the size of the message in bytes."},
|
||||
{"Clear", (PyCFunction)Clear, METH_NOARGS, "Clears the message."},
|
||||
|
|
|
@ -1044,6 +1044,35 @@ static PyObject* PyUpb_Message_HasField(PyObject* _self, PyObject* arg) {
|
|||
NULL);
|
||||
}
|
||||
|
||||
static PyObject* PyUpb_Message_Contains(PyObject* _self, PyObject* arg) {
|
||||
const upb_MessageDef* msgdef = PyUpb_Message_GetMsgdef(_self);
|
||||
switch (upb_MessageDef_WellKnownType(msgdef)) {
|
||||
case kUpb_WellKnown_Struct: {
|
||||
// For WKT Struct, check if the key is in the fields.
|
||||
PyUpb_Message* self = (void*)_self;
|
||||
if (PyUpb_Message_IsStub(self)) Py_RETURN_FALSE;
|
||||
upb_Message* msg = PyUpb_Message_GetMsg(self);
|
||||
const upb_FieldDef* f = upb_MessageDef_FindFieldByName(msgdef, "fields");
|
||||
const upb_Map* map = upb_Message_GetFieldByDef(msg, f).map_val;
|
||||
const upb_MessageDef* entry_m = upb_FieldDef_MessageSubDef(f);
|
||||
const upb_FieldDef* key_f = upb_MessageDef_Field(entry_m, 0);
|
||||
upb_MessageValue u_key;
|
||||
if (!PyUpb_PyToUpb(arg, key_f, &u_key, NULL)) return NULL;
|
||||
return PyBool_FromLong(upb_Map_Get(map, u_key, NULL));
|
||||
}
|
||||
case kUpb_WellKnown_ListValue: {
|
||||
// For WKT ListValue, check if the key is in the items.
|
||||
PyUpb_Message* self = (void*)_self;
|
||||
if (PyUpb_Message_IsStub(self)) Py_RETURN_FALSE;
|
||||
PyObject* items = PyObject_CallMethod(_self, "items", NULL);
|
||||
return PyBool_FromLong(PySequence_Contains(items, arg));
|
||||
}
|
||||
default:
|
||||
// For other messages, check with HasField.
|
||||
return PyUpb_Message_HasField(_self, arg);
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject* PyUpb_Message_FindInitializationErrors(PyObject* _self,
|
||||
PyObject* arg);
|
||||
|
||||
|
@ -1642,6 +1671,8 @@ static PyMethodDef PyUpb_Message_Methods[] = {
|
|||
// TODO
|
||||
//{ "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
|
||||
// "Outputs a unicode representation of the message." },
|
||||
{"__contains__", PyUpb_Message_Contains, METH_O,
|
||||
"Checks if a message field is set."},
|
||||
{"ByteSize", (PyCFunction)PyUpb_Message_ByteSize, METH_NOARGS,
|
||||
"Returns the size of the message in bytes."},
|
||||
{"Clear", (PyCFunction)PyUpb_Message_Clear, METH_NOARGS,
|
||||
|
|
Loading…
Reference in New Issue