Down-integrate from google3.

pull/5019/head
Feng Xiao 2018-08-08 17:00:41 -07:00
parent e7746f487c
commit 6bbe197e9c
443 changed files with 30362 additions and 19400 deletions

View File

@ -804,6 +804,7 @@ python_EXTRA_DIST= \
python/google/protobuf/json_format.py \
python/google/protobuf/message.py \
python/google/protobuf/message_factory.py \
python/google/protobuf/python_api.h \
python/google/protobuf/python_protobuf.h \
python/google/protobuf/proto_api.h \
python/google/protobuf/proto_builder.py \

View File

@ -30,7 +30,7 @@
#include <fstream>
#include <iostream>
#include "benchmark/benchmark_api.h"
#include "benchmark/benchmark.h"
#include "benchmarks.pb.h"
#include "datasets/google_message1/proto2/benchmark_message1_proto2.pb.h"
#include "datasets/google_message1/proto3/benchmark_message1_proto3.pb.h"

View File

@ -78,6 +78,7 @@ set(libprotoc_files
${protobuf_source_dir}/src/google/protobuf/compiler/plugin.pb.cc
${protobuf_source_dir}/src/google/protobuf/compiler/python/python_generator.cc
${protobuf_source_dir}/src/google/protobuf/compiler/ruby/ruby_generator.cc
${protobuf_source_dir}/src/google/protobuf/compiler/scc.cc
${protobuf_source_dir}/src/google/protobuf/compiler/subprocess.cc
${protobuf_source_dir}/src/google/protobuf/compiler/zip_writer.cc
)
@ -153,6 +154,7 @@ set(libprotoc_headers
${protobuf_source_dir}/src/google/protobuf/compiler/objectivec/objectivec_message_field.h
${protobuf_source_dir}/src/google/protobuf/compiler/objectivec/objectivec_oneof.h
${protobuf_source_dir}/src/google/protobuf/compiler/objectivec/objectivec_primitive_field.h
${protobuf_source_dir}/src/google/protobuf/compiler/scc.h
${protobuf_source_dir}/src/google/protobuf/compiler/subprocess.h
${protobuf_source_dir}/src/google/protobuf/compiler/zip_writer.h
)
@ -177,3 +179,4 @@ set_target_properties(libprotoc PROPERTIES
OUTPUT_NAME ${LIB_PREFIX}protoc
DEBUG_POSTFIX "${protobuf_DEBUG_POSTFIX}")
add_library(protobuf::libprotoc ALIAS libprotoc)

View File

@ -63,6 +63,7 @@ set(tests_protos
google/protobuf/unittest_optimize_for.proto
google/protobuf/unittest_preserve_unknown_enum.proto
google/protobuf/unittest_preserve_unknown_enum2.proto
google/protobuf/unittest_proto3.proto
google/protobuf/unittest_proto3_arena.proto
google/protobuf/unittest_proto3_arena_lite.proto
google/protobuf/unittest_proto3_lite.proto
@ -78,6 +79,7 @@ set(tests_protos
google/protobuf/util/internal/testdata/struct.proto
google/protobuf/util/internal/testdata/timestamp_duration.proto
google/protobuf/util/internal/testdata/wrappers.proto
google/protobuf/util/json_format.proto
google/protobuf/util/json_format_proto3.proto
google/protobuf/util/message_differencer_unittest.proto
)

View File

@ -233,10 +233,14 @@ class ConformanceJava {
}
case JSON_PAYLOAD: {
try {
TestMessagesProto3.TestAllTypesProto3.Builder builder =
TestMessagesProto3.TestAllTypesProto3.Builder builder =
TestMessagesProto3.TestAllTypesProto3.newBuilder();
JsonFormat.parser().usingTypeRegistry(typeRegistry)
.merge(request.getJsonPayload(), builder);
JsonFormat.Parser parser = JsonFormat.parser().usingTypeRegistry(typeRegistry);
if (request.getTestCategory()
== Conformance.TestCategory.JSON_IGNORE_UNKNOWN_PARSING_TEST) {
parser = parser.ignoringUnknownFields();
}
parser.merge(request.getJsonPayload(), builder);
testMessage = builder.build();
} catch (InvalidProtocolBufferException e) {
return Conformance.ConformanceResponse.newBuilder().setParseError(e.getMessage()).build();

View File

@ -279,6 +279,8 @@ $(protoc_outputs): protoc_middleman
$(other_language_protoc_outputs): protoc_middleman
BUILT_SOURCES = $(protoc_outputs) $(other_language_protoc_outputs)
CLEANFILES = $(protoc_outputs) protoc_middleman javac_middleman conformance-java javac_middleman_lite conformance-java-lite conformance-csharp conformance-php conformance-php-c $(other_language_protoc_outputs)
MAINTAINERCLEANFILES = \

View File

@ -57,6 +57,18 @@ enum WireFormat {
JSON = 2;
}
enum TestCategory {
UNSPECIFIED_TEST = 0;
BINARY_TEST = 1; // Test binary wire format.
JSON_TEST = 2; // Test json wire format.
// Similar to JSON_TEST. However, during parsing json, testee should ignore
// unknown fields. This feature is optional. Each implementation can descide
// whether to support it. See
// https://developers.google.com/protocol-buffers/docs/proto3#json_options
// for more detail.
JSON_IGNORE_UNKNOWN_PARSING_TEST = 3;
}
// Represents a single test case's input. The testee should:
//
// 1. parse this proto (which should always succeed)
@ -83,7 +95,10 @@ message ConformanceRequest {
// protobuf_test_messages.proto2.TestAllTypesProto2.
string message_type = 4;
bool ignore_unknown_json = 5;
// Each test is given a specific test category. Some category may need
// spedific support in testee programs. Refer to the defintion of TestCategory
// for more information.
TestCategory test_category = 5;
}
// Represents a single test case's output.

View File

@ -46,6 +46,7 @@ using google::protobuf::DescriptorPool;
using google::protobuf::Message;
using google::protobuf::MessageFactory;
using google::protobuf::util::BinaryToJsonString;
using google::protobuf::util::JsonParseOptions;
using google::protobuf::util::JsonToBinaryString;
using google::protobuf::util::NewTypeResolverForDescriptorPool;
using google::protobuf::util::Status;
@ -112,8 +113,13 @@ void DoTest(const ConformanceRequest& request, ConformanceResponse* response) {
case ConformanceRequest::kJsonPayload: {
string proto_binary;
JsonParseOptions options;
options.ignore_unknown_fields =
(request.test_category() ==
conformance::JSON_IGNORE_UNKNOWN_PARSING_TEST);
Status status = JsonToBinaryString(type_resolver, *type_url,
request.json_payload(), &proto_binary);
request.json_payload(), &proto_binary,
options);
if (!status.ok()) {
response->set_parse_error(string("Parse error: ") +
status.error_message().as_string());

View File

@ -3,6 +3,7 @@
require_once("Conformance/WireFormat.php");
require_once("Conformance/ConformanceResponse.php");
require_once("Conformance/ConformanceRequest.php");
require_once("Conformance/TestCategory.php");
require_once("Protobuf_test_messages/Proto3/ForeignMessage.php");
require_once("Protobuf_test_messages/Proto3/ForeignEnum.php");
require_once("Protobuf_test_messages/Proto3/TestAllTypesProto3.php");
@ -12,6 +13,7 @@ require_once("Protobuf_test_messages/Proto3/TestAllTypesProto3/NestedEnum.php");
require_once("GPBMetadata/Conformance.php");
require_once("GPBMetadata/Google/Protobuf/TestMessagesProto3.php");
use \Conformance\TestCategory;
use \Conformance\WireFormat;
if (!ini_get("date.timezone")) {
@ -39,7 +41,9 @@ function doTest($request)
trigger_error("Protobuf request doesn't have specific payload type", E_USER_ERROR);
}
} elseif ($request->getPayload() == "json_payload") {
$ignore_json_unknown = $request->getIgnoreUnknownJson();
$ignore_json_unknown =
($request->getTestCategory() ==
TestCategory::JSON_IGNORE_UNKNOWN_PARSING_TEST);
try {
$test_message->mergeFromJsonString($request->getJsonPayload(),
$ignore_json_unknown);

View File

@ -78,7 +78,11 @@ def do_test(request):
elif request.WhichOneof('payload') == 'json_payload':
try:
json_format.Parse(request.json_payload, test_message)
ignore_unknown_fields = \
request.test_category == \
conformance_pb2.JSON_IGNORE_UNKNOWN_PARSING_TEST
json_format.Parse(request.json_payload, test_message,
ignore_unknown_fields)
except Exception as e:
response.parse_error = str(e)
return response

File diff suppressed because it is too large Load Diff

View File

@ -40,12 +40,13 @@
#include <functional>
#include <string>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/util/type_resolver.h>
#include <google/protobuf/wire_format_lite.h>
#include "conformance.pb.h"
#include "third_party/jsoncpp/json.h"
namespace conformance {
class ConformanceRequest;
@ -78,7 +79,23 @@ class ConformanceTestRunner {
};
// Class representing the test suite itself. To run it, implement your own
// class derived from ConformanceTestRunner and then write code like:
// class derived from ConformanceTestRunner, class derived from
// ConformanceTestSuite and then write code like:
//
// class MyConformanceTestSuite : public ConformanceTestSuite {
// public:
// void RunSuiteImpl() {
// // INSERT ACTURAL TESTS.
// }
// };
//
// // Force MyConformanceTestSuite to be added at dynamic initialization
// // time.
// struct StaticTestSuiteInitializer {
// StaticTestSuiteInitializer() {
// AddTestSuite(new MyConformanceTestSuite());
// }
// } static_test_suite_initializer;
//
// class MyConformanceTestRunner : public ConformanceTestRunner {
// public:
@ -89,15 +106,17 @@ class ConformanceTestRunner {
//
// int main() {
// MyConformanceTestRunner runner;
// google::protobuf::ConformanceTestSuite suite;
//
// std::string output;
// suite.RunSuite(&runner, &output);
// const std::set<ConformanceTestSuite*>& test_suite_set =
// ::google::protobuf::GetTestSuiteSet();
// for (auto suite : test_suite_set) {
// suite->RunSuite(&runner, &output);
// }
// }
//
class ConformanceTestSuite {
public:
ConformanceTestSuite() : verbose_(false), enforce_recommended_(false) {}
virtual ~ConformanceTestSuite() {}
void SetVerbose(bool verbose) { verbose_ = verbose; }
@ -130,7 +149,7 @@ class ConformanceTestSuite {
// tests passed.
bool RunSuite(ConformanceTestRunner* runner, std::string* output);
private:
protected:
// Test cases are classified into a few categories:
// REQUIRED: the test case must be passed for an implementation to be
// interoperable with other implementations. For example, a
@ -151,38 +170,43 @@ class ConformanceTestSuite {
class ConformanceRequestSetting {
public:
ConformanceRequestSetting(
ConformanceLevel level, conformance::WireFormat input_format,
conformance::WireFormat output_format, bool is_proto3,
ConformanceLevel level,
conformance::WireFormat input_format,
conformance::WireFormat output_format,
conformance::TestCategory test_category,
const Message& prototype_message,
const string& test_name, const string& input);
virtual ~ConformanceRequestSetting() {}
Message* GetTestMessage() const;
Message* GetTestMessage() const;
const string& GetTestName() const {
return test_name_;
}
string GetTestName() const;
const conformance::ConformanceRequest& GetRequest() const {
return request_;
}
const conformance::ConformanceRequest& GetRequest() const {
return request_;
}
const ConformanceLevel GetLevel() const {
return level_;
}
const ConformanceLevel GetLevel() const {
return level_;
}
string ConformanceLevelToString(ConformanceLevel level) const;
protected:
virtual string InputFormatString(conformance::WireFormat format) const;
virtual string OutputFormatString(conformance::WireFormat format) const;
void SetIgnoreUnknownJson(bool ignore_unknown_json) {
request_.set_ignore_unknown_json(ignore_unknown_json);
}
private:
ConformanceLevel level_;
conformance::WireFormat input_format_;
conformance::WireFormat output_format_;
bool is_proto3_;
::conformance::WireFormat input_format_;
::conformance::WireFormat output_format_;
const Message& prototype_message_;
string test_name_;
conformance::ConformanceRequest request_;
};
static string ConformanceLevelToString(ConformanceLevel level);
bool CheckSetEmpty(const std::set<string>& set_to_check,
const std::string& write_to_file, const std::string& msg);
void ReportSuccess(const std::string& test_name);
void ReportFailure(const string& test_name,
@ -193,73 +217,18 @@ class ConformanceTestSuite {
void ReportSkip(const string& test_name,
const conformance::ConformanceRequest& request,
const conformance::ConformanceResponse& response);
void RunTest(const std::string& test_name,
const conformance::ConformanceRequest& request,
conformance::ConformanceResponse* response);
void RunValidInputTest(const ConformanceRequestSetting& setting,
const string& equivalent_text_format);
void RunValidBinaryInputTest(const ConformanceRequestSetting& setting,
const string& equivalent_wire_format);
void RunValidJsonTest(const string& test_name,
ConformanceLevel level,
const string& input_json,
const string& equivalent_text_format);
void RunValidJsonIgnoreUnknownTest(
const string& test_name, ConformanceLevel level, const string& input_json,
const string& equivalent_text_format);
void RunValidJsonTestWithProtobufInput(
const string& test_name,
ConformanceLevel level,
const protobuf_test_messages::proto3::TestAllTypesProto3& input,
const string& equivalent_text_format);
void RunValidProtobufTest(const string& test_name, ConformanceLevel level,
const string& input_protobuf,
const string& equivalent_text_format,
bool isProto3);
void RunValidBinaryProtobufTest(const string& test_name,
ConformanceLevel level,
const string& input_protobuf,
bool isProto3);
void RunValidProtobufTestWithMessage(
const string& test_name, ConformanceLevel level,
const Message *input,
const string& equivalent_text_format,
bool isProto3);
typedef std::function<bool(const Json::Value&)> Validator;
void RunValidJsonTestWithValidator(const string& test_name,
ConformanceLevel level,
const string& input_json,
const Validator& validator);
void ExpectParseFailureForJson(const string& test_name,
ConformanceLevel level,
const string& input_json);
void ExpectSerializeFailureForJson(const string& test_name,
ConformanceLevel level,
const string& text_format);
void ExpectParseFailureForProtoWithProtoVersion (const string& proto,
const string& test_name,
ConformanceLevel level,
bool isProto3);
void ExpectParseFailureForProto(const std::string& proto,
const std::string& test_name,
ConformanceLevel level);
void ExpectHardParseFailureForProto(const std::string& proto,
const std::string& test_name,
ConformanceLevel level);
void TestPrematureEOFForType(google::protobuf::FieldDescriptor::Type type);
void TestIllegalTags();
template <class MessageType>
void TestOneofMessage (MessageType &message,
bool isProto3);
template <class MessageType>
void TestUnknownMessage (MessageType &message,
bool isProto3);
void TestValidDataForType(
google::protobuf::FieldDescriptor::Type,
std::vector<std::pair<std::string, std::string>> values);
bool CheckSetEmpty(const std::set<string>& set_to_check,
const std::string& write_to_file, const std::string& msg);
void RunTest(const std::string& test_name,
const conformance::ConformanceRequest& request,
conformance::ConformanceResponse* response);
virtual void RunSuiteImpl() = 0;
ConformanceTestRunner* runner_;
int successes_;
int expected_failures_;
@ -285,10 +254,14 @@ class ConformanceTestSuite {
// The set of tests that the testee opted out of;
std::set<std::string> skipped_;
std::unique_ptr<google::protobuf::util::TypeResolver> type_resolver_;
std::unique_ptr<google::protobuf::util::TypeResolver>
type_resolver_;
std::string type_url_;
};
void AddTestSuite(ConformanceTestSuite* suite);
const std::set<ConformanceTestSuite*>& GetTestSuiteSet();
} // namespace protobuf
} // namespace google

File diff suppressed because it is too large Load Diff

View File

@ -66,9 +66,9 @@
#include "conformance.pb.h"
#include "conformance_test.h"
using conformance::ConformanceRequest;
using conformance::ConformanceResponse;
using google::protobuf::StringAppendF;
using google::protobuf::ConformanceTestSuite;
using std::string;
using std::vector;
@ -287,7 +287,8 @@ void ParseFailureList(const char *filename, std::vector<string>* failure_list) {
int main(int argc, char *argv[]) {
char *program;
google::protobuf::ConformanceTestSuite suite;
const std::set<ConformanceTestSuite*>& test_suite_set =
::google::protobuf::GetTestSuiteSet();
string failure_list_filename;
std::vector<string> failure_list;
@ -298,9 +299,13 @@ int main(int argc, char *argv[]) {
failure_list_filename = argv[arg];
ParseFailureList(argv[arg], &failure_list);
} else if (strcmp(argv[arg], "--verbose") == 0) {
suite.SetVerbose(true);
for (auto *suite : test_suite_set) {
suite->SetVerbose(true);
}
} else if (strcmp(argv[arg], "--enforce_recommended") == 0) {
suite.SetEnforceRecommended(true);
for (auto suite : test_suite_set) {
suite->SetEnforceRecommended(true);
}
} else if (argv[arg][0] == '-') {
fprintf(stderr, "Unknown option: %s\n", argv[arg]);
UsageError();
@ -313,11 +318,16 @@ int main(int argc, char *argv[]) {
}
}
suite.SetFailureList(failure_list_filename, failure_list);
for (auto suite : test_suite_set) {
suite->SetFailureList(failure_list_filename, failure_list);
}
ForkPipeRunner runner(program);
std::string output;
bool ok = suite.RunSuite(&runner, &output);
bool ok = true;
for (auto suite : test_suite_set) {
ok &= suite->RunSuite(&runner, &output);
}
fwrite(output.c_str(), 1, output.size(), stderr);

View File

@ -54,9 +54,3 @@ Required.Proto2.ProtobufInput.PrematureEofInPackedField.SINT32
Required.Proto2.ProtobufInput.PrematureEofInPackedField.SINT64
Required.Proto2.ProtobufInput.PrematureEofInPackedField.UINT32
Required.Proto2.ProtobufInput.PrematureEofInPackedField.UINT64
Required.Proto3.JsonInput.IgnoreUnknownJsonFalse.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNull.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNumber.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonObject.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonString.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonTrue.ProtobufOutput

View File

@ -1,8 +1,2 @@
Recommended.Proto3.JsonInput.BytesFieldBase64Url.JsonOutput
Recommended.Proto3.JsonInput.BytesFieldBase64Url.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonFalse.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNull.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNumber.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonObject.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonString.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonTrue.ProtobufOutput

View File

@ -45,9 +45,3 @@ Required.Proto3.ProtobufInput.PrematureEofInDelimitedDataForKnownNonRepeatedValu
Required.Proto3.ProtobufInput.PrematureEofInDelimitedDataForKnownRepeatedValue.MESSAGE
Required.Proto2.ProtobufInput.PrematureEofInDelimitedDataForKnownNonRepeatedValue.MESSAGE
Required.Proto2.ProtobufInput.PrematureEofInDelimitedDataForKnownRepeatedValue.MESSAGE
Required.Proto3.JsonInput.IgnoreUnknownJsonFalse.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNull.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNumber.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonObject.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonString.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonTrue.ProtobufOutput

View File

@ -1,6 +1,8 @@
Recommended.FieldMaskNumbersDontRoundTrip.JsonOutput
Recommended.FieldMaskPathsDontRoundTrip.JsonOutput
Recommended.FieldMaskTooManyUnderscore.JsonOutput
Recommended.Proto3.JsonInput.BytesFieldBase64Url.JsonOutput
Recommended.Proto3.JsonInput.BytesFieldBase64Url.ProtobufOutput
Recommended.Proto3.JsonInput.DurationHas3FractionalDigits.Validator
Recommended.Proto3.JsonInput.DurationHas6FractionalDigits.Validator
Recommended.Proto3.JsonInput.DurationHas9FractionalDigits.Validator
@ -14,6 +16,11 @@ Recommended.Proto3.JsonInput.StringEndsWithEscapeChar
Recommended.Proto3.JsonInput.StringFieldSurrogateInWrongOrder
Recommended.Proto3.JsonInput.StringFieldUnpairedHighSurrogate
Recommended.Proto3.JsonInput.StringFieldUnpairedLowSurrogate
Recommended.Proto3.JsonInput.TimestampHas3FractionalDigits.Validator
Recommended.Proto3.JsonInput.TimestampHas6FractionalDigits.Validator
Recommended.Proto3.JsonInput.TimestampHas9FractionalDigits.Validator
Recommended.Proto3.JsonInput.TimestampHasZeroFractionalDigit.Validator
Recommended.Proto3.JsonInput.TimestampZeroNormalized.Validator
Recommended.Proto3.JsonInput.Uint64FieldBeString.Validator
Recommended.Proto3.ProtobufInput.OneofZeroBytes.JsonOutput
Required.DurationProtoInputTooLarge.JsonOutput
@ -55,6 +62,7 @@ Required.Proto3.JsonInput.FieldMask.ProtobufOutput
Required.Proto3.JsonInput.FloatFieldInfinity.JsonOutput
Required.Proto3.JsonInput.FloatFieldNan.JsonOutput
Required.Proto3.JsonInput.FloatFieldNegativeInfinity.JsonOutput
Required.Proto3.JsonInput.OneofFieldDuplicate
Required.Proto3.JsonInput.OptionalBoolWrapper.JsonOutput
Required.Proto3.JsonInput.OptionalBoolWrapper.ProtobufOutput
Required.Proto3.JsonInput.OptionalBytesWrapper.JsonOutput
@ -103,6 +111,16 @@ Required.Proto3.JsonInput.StringFieldUnicodeEscapeWithLowercaseHexLetters.JsonOu
Required.Proto3.JsonInput.StringFieldUnicodeEscapeWithLowercaseHexLetters.ProtobufOutput
Required.Proto3.JsonInput.Struct.JsonOutput
Required.Proto3.JsonInput.Struct.ProtobufOutput
Required.Proto3.JsonInput.TimestampMaxValue.JsonOutput
Required.Proto3.JsonInput.TimestampMaxValue.ProtobufOutput
Required.Proto3.JsonInput.TimestampMinValue.JsonOutput
Required.Proto3.JsonInput.TimestampMinValue.ProtobufOutput
Required.Proto3.JsonInput.TimestampRepeatedValue.JsonOutput
Required.Proto3.JsonInput.TimestampRepeatedValue.ProtobufOutput
Required.Proto3.JsonInput.TimestampWithNegativeOffset.JsonOutput
Required.Proto3.JsonInput.TimestampWithNegativeOffset.ProtobufOutput
Required.Proto3.JsonInput.TimestampWithPositiveOffset.JsonOutput
Required.Proto3.JsonInput.TimestampWithPositiveOffset.ProtobufOutput
Required.Proto3.JsonInput.ValueAcceptBool.JsonOutput
Required.Proto3.JsonInput.ValueAcceptBool.ProtobufOutput
Required.Proto3.JsonInput.ValueAcceptFloat.JsonOutput
@ -111,6 +129,8 @@ Required.Proto3.JsonInput.ValueAcceptInteger.JsonOutput
Required.Proto3.JsonInput.ValueAcceptInteger.ProtobufOutput
Required.Proto3.JsonInput.ValueAcceptList.JsonOutput
Required.Proto3.JsonInput.ValueAcceptList.ProtobufOutput
Required.Proto3.JsonInput.ValueAcceptListWithNull.JsonOutput
Required.Proto3.JsonInput.ValueAcceptListWithNull.ProtobufOutput
Required.Proto3.JsonInput.ValueAcceptNull.JsonOutput
Required.Proto3.JsonInput.ValueAcceptNull.ProtobufOutput
Required.Proto3.JsonInput.ValueAcceptObject.JsonOutput

View File

@ -1,8 +1,2 @@
JsonInput.StringFieldSurrogateInWrongOrder
JsonInput.StringFieldUnpairedHighSurrogate
Required.Proto3.JsonInput.IgnoreUnknownJsonFalse.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNull.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNumber.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonObject.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonString.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonTrue.ProtobufOutput

View File

@ -19,9 +19,3 @@ Required.Proto3.ProtobufInput.IllegalZeroFieldNum_Case_0
Required.Proto3.ProtobufInput.IllegalZeroFieldNum_Case_1
Required.Proto3.ProtobufInput.IllegalZeroFieldNum_Case_2
Required.Proto3.ProtobufInput.IllegalZeroFieldNum_Case_3
Required.Proto3.JsonInput.IgnoreUnknownJsonFalse.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNull.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNumber.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonObject.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonString.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonTrue.ProtobufOutput

View File

@ -52,9 +52,3 @@ Required.Proto2.ProtobufInput.PrematureEofInPackedField.SINT32
Required.Proto2.ProtobufInput.PrematureEofInPackedField.SINT64
Required.Proto2.ProtobufInput.PrematureEofInPackedField.UINT32
Required.Proto2.ProtobufInput.PrematureEofInPackedField.UINT64
Required.Proto3.JsonInput.IgnoreUnknownJsonFalse.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNull.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNumber.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonObject.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonString.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonTrue.ProtobufOutput

View File

@ -120,8 +120,6 @@ Required.Proto3.JsonInput.ValueAcceptInteger.JsonOutput
Required.Proto3.JsonInput.ValueAcceptInteger.ProtobufOutput
Required.Proto3.JsonInput.ValueAcceptList.JsonOutput
Required.Proto3.JsonInput.ValueAcceptList.ProtobufOutput
Required.Proto3.JsonInput.ValueAcceptListWithNull.JsonOutput
Required.Proto3.JsonInput.ValueAcceptListWithNull.ProtobufOutput
Required.Proto3.JsonInput.ValueAcceptNull.JsonOutput
Required.Proto3.JsonInput.ValueAcceptNull.ProtobufOutput
Required.Proto3.JsonInput.ValueAcceptObject.JsonOutput
@ -135,9 +133,3 @@ Required.Proto3.ProtobufInput.FloatFieldNormalizeSignalingNan.JsonOutput
Required.Proto3.ProtobufInput.ValidDataRepeated.FLOAT.JsonOutput
Required.TimestampProtoInputTooLarge.JsonOutput
Required.TimestampProtoInputTooSmall.JsonOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonFalse.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNull.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonNumber.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonObject.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonString.ProtobufOutput
Required.Proto3.JsonInput.IgnoreUnknownJsonTrue.ProtobufOutput

View File

@ -32,7 +32,6 @@ package com.google.protobuf;
import com.google.protobuf.Descriptors.EnumValueDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor.Syntax;
import com.google.protobuf.Descriptors.OneofDescriptor;
import com.google.protobuf.Internal.EnumLite;
import java.io.IOException;
@ -446,10 +445,7 @@ public abstract class AbstractMessage
final CodedInputStream input,
final ExtensionRegistryLite extensionRegistry)
throws IOException {
boolean discardUnknown =
getDescriptorForType().getFile().getSyntax() == Syntax.PROTO3
? input.shouldDiscardUnknownFieldsProto3()
: input.shouldDiscardUnknownFields();
boolean discardUnknown = input.shouldDiscardUnknownFields();
final UnknownFieldSet.Builder unknownFields =
discardUnknown ? null : UnknownFieldSet.newBuilder(getUnknownFields());
while (true) {

View File

@ -42,7 +42,8 @@ import java.util.RandomAccess;
*
* @author dweis@google.com (Daniel Weis)
*/
final class BooleanArrayList extends AbstractProtobufList<Boolean>
final class BooleanArrayList
extends AbstractProtobufList<Boolean>
implements BooleanList, RandomAccess, PrimitiveNonBoxingCollection {
private static final BooleanArrayList EMPTY_LIST = new BooleanArrayList();

View File

@ -46,6 +46,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
@ -221,6 +222,67 @@ public abstract class ByteString implements Iterable<Byte>, Serializable {
return size() == 0;
}
// =================================================================
// Comparison
private static final int UNSIGNED_BYTE_MASK = 0xFF;
/**
* Returns the value of the given byte as an integer, interpreting the byte as an unsigned value.
* That is, returns {@code value + 256} if {@code value} is negative; {@code value} itself
* otherwise.
*
* <p>Note: This code was copied from {@link com.google.common.primitives.UnsignedBytes#toInt}, as
* Guava libraries cannot be used in the {@code com.google.protobuf} package.
*/
private static int toInt(byte value) {
return value & UNSIGNED_BYTE_MASK;
}
/**
* Compares two {@link ByteString}s lexicographically, treating their contents as unsigned byte
* values between 0 and 255 (inclusive).
*
* <p>For example, {@code (byte) -1} is considered to be greater than {@code (byte) 1} because
* it is interpreted as an unsigned value, {@code 255}.
*/
private static final Comparator<ByteString> UNSIGNED_LEXICOGRAPHICAL_COMPARATOR =
new Comparator<ByteString>() {
@Override
public int compare(ByteString former, ByteString latter) {
ByteIterator formerBytes = former.iterator();
ByteIterator latterBytes = latter.iterator();
while (formerBytes.hasNext() && latterBytes.hasNext()) {
// Note: This code was copied from com.google.common.primitives.UnsignedBytes#compare,
// as Guava libraries cannot be used in the {@code com.google.protobuf} package.
int result =
Integer.compare(toInt(formerBytes.nextByte()), toInt(latterBytes.nextByte()));
if (result != 0) {
return result;
}
}
return Integer.compare(former.size(), latter.size());
}
};
/**
* Returns a {@link Comparator<ByteString>} which compares {@link ByteString}-s lexicographically
* as sequences of unsigned bytes (i.e. values between 0 and 255, inclusive).
*
* <p>For example, {@code (byte) -1} is considered to be greater than {@code (byte) 1} because
* it is interpreted as an unsigned value, {@code 255}:
*
* <ul>
* <li>{@code `-1` -> 0b11111111 (two's complement) -> 255}
* <li>{@code `1` -> 0b00000001 -> 1}
* </ul>
*/
public static Comparator<ByteString> unsignedLexicographicalComparator() {
return UNSIGNED_LEXICOGRAPHICAL_COMPARATOR;
}
// =================================================================
// ByteString -> substring
@ -287,8 +349,10 @@ public abstract class ByteString implements Iterable<Byte>, Serializable {
* @param offset offset in source array
* @param size number of bytes to copy
* @return new {@code ByteString}
* @throws IndexOutOfBoundsException if {@code offset} or {@code size} are out of bounds
*/
public static ByteString copyFrom(byte[] bytes, int offset, int size) {
checkRange(offset, offset + size, bytes.length);
return new LiteralByteString(byteArrayCopier.copyFrom(bytes, offset, size));
}
@ -339,8 +403,10 @@ public abstract class ByteString implements Iterable<Byte>, Serializable {
* @param bytes source buffer
* @param size number of bytes to copy
* @return new {@code ByteString}
* @throws IndexOutOfBoundsException if {@code size > bytes.remaining()}
*/
public static ByteString copyFrom(ByteBuffer bytes, int size) {
checkRange(0, size, bytes.remaining());
byte[] copy = new byte[size];
bytes.get(copy);
return new LiteralByteString(copy);
@ -578,6 +644,9 @@ public abstract class ByteString implements Iterable<Byte>, Serializable {
/**
* Copies bytes into a buffer at the given offset.
*
* <p>To copy a subset of bytes, you call this method on the return value of {@link
* #substring(int, int)}. Example: {@code byteString.substring(start, end).copyTo(target, offset)}
*
* @param target buffer to copy into
* @param offset in the target buffer
* @throws IndexOutOfBoundsException if the offset is negative or too large
@ -589,15 +658,16 @@ public abstract class ByteString implements Iterable<Byte>, Serializable {
/**
* Copies bytes into a buffer.
*
* @param target buffer to copy into
* @param target buffer to copy into
* @param sourceOffset offset within these bytes
* @param targetOffset offset within the target buffer
* @param numberToCopy number of bytes to copy
* @throws IndexOutOfBoundsException if an offset or size is negative or too
* large
* @throws IndexOutOfBoundsException if an offset or size is negative or too large
* @deprecation Instead, call {@code byteString.substring(sourceOffset, sourceOffset +
* numberToCopy).copyTo(target, targetOffset)}
*/
public final void copyTo(byte[] target, int sourceOffset, int targetOffset,
int numberToCopy) {
@Deprecated
public final void copyTo(byte[] target, int sourceOffset, int targetOffset, int numberToCopy) {
checkRange(sourceOffset, sourceOffset + numberToCopy, size());
checkRange(targetOffset, targetOffset + numberToCopy, target.length);
if (numberToCopy > 0) {
@ -617,10 +687,13 @@ public abstract class ByteString implements Iterable<Byte>, Serializable {
/**
* Copies bytes into a ByteBuffer.
*
* <p>To copy a subset of bytes, you call this method on the return value of {@link
* #substring(int, int)}. Example: {@code byteString.substring(start, end).copyTo(target)}
*
* @param target ByteBuffer to copy into.
* @throws java.nio.ReadOnlyBufferException if the {@code target} is read-only
* @throws java.nio.BufferOverflowException if the {@code target}'s
* remaining() space is not large enough to hold the data.
* @throws java.nio.BufferOverflowException if the {@code target}'s remaining() space is not large
* enough to hold the data.
*/
public abstract void copyTo(ByteBuffer target);
@ -1258,6 +1331,9 @@ public abstract class ByteString implements Iterable<Byte>, Serializable {
* @param bytes array to wrap
*/
LiteralByteString(byte[] bytes) {
if (bytes == null) {
throw new NullPointerException();
}
this.bytes = bytes;
}

View File

@ -64,12 +64,6 @@ public abstract class CodedInputStream {
// Integer.MAX_VALUE == 0x7FFFFFF == INT_MAX from limits.h
private static final int DEFAULT_SIZE_LIMIT = Integer.MAX_VALUE;
/**
* Whether to enable our custom UTF-8 decode codepath which does not use {@link StringCoding}.
* Currently disabled.
*/
private static final boolean ENABLE_CUSTOM_UTF8_DECODE = false;
/** Visible for subclasses. See setRecursionLimit() */
int recursionDepth;
@ -417,21 +411,7 @@ public abstract class CodedInputStream {
}
private boolean explicitDiscardUnknownFields = false;
private static volatile boolean proto3DiscardUnknownFieldsDefault = false;
static void setProto3DiscardUnknownsByDefaultForTest() {
proto3DiscardUnknownFieldsDefault = true;
}
static void setProto3KeepUnknownsByDefaultForTest() {
proto3DiscardUnknownFieldsDefault = false;
}
static boolean getProto3DiscardUnknownFieldsDefault() {
return proto3DiscardUnknownFieldsDefault;
}
private boolean shouldDiscardUnknownFields = false;
/**
* Sets this {@code CodedInputStream} to discard unknown fields. Only applies to full runtime
@ -442,7 +422,7 @@ public abstract class CodedInputStream {
* runtime.
*/
final void discardUnknownFields() {
explicitDiscardUnknownFields = true;
shouldDiscardUnknownFields = true;
}
/**
@ -450,7 +430,7 @@ public abstract class CodedInputStream {
* default.
*/
final void unsetDiscardUnknownFields() {
explicitDiscardUnknownFields = false;
shouldDiscardUnknownFields = false;
}
/**
@ -458,19 +438,7 @@ public abstract class CodedInputStream {
* runtime messages.
*/
final boolean shouldDiscardUnknownFields() {
return explicitDiscardUnknownFields;
}
/**
* Whether unknown fields in this input stream should be discarded during parsing for proto3 full
* runtime messages.
*
* <p>This function was temporarily introduced before proto3 unknown fields behavior is changed.
* TODO(liujisi): remove this and related code in GeneratedMessage after proto3 unknown
* fields migration is done.
*/
final boolean shouldDiscardUnknownFieldsProto3() {
return explicitDiscardUnknownFields ? true : proto3DiscardUnknownFieldsDefault;
return shouldDiscardUnknownFields;
}
/**
@ -831,19 +799,9 @@ public abstract class CodedInputStream {
public String readStringRequireUtf8() throws IOException {
final int size = readRawVarint32();
if (size > 0 && size <= (limit - pos)) {
if (ENABLE_CUSTOM_UTF8_DECODE) {
String result = Utf8.decodeUtf8(buffer, pos, size);
pos += size;
return result;
} else {
// TODO(martinrb): We could save a pass by validating while decoding.
if (!Utf8.isValidUtf8(buffer, pos, pos + size)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
final int tempPos = pos;
pos += size;
return new String(buffer, tempPos, size, UTF_8);
}
String result = Utf8.decodeUtf8(buffer, pos, size);
pos += size;
return result;
}
if (size == 0) {
@ -1559,25 +1517,10 @@ public abstract class CodedInputStream {
public String readStringRequireUtf8() throws IOException {
final int size = readRawVarint32();
if (size > 0 && size <= remaining()) {
if (ENABLE_CUSTOM_UTF8_DECODE) {
final int bufferPos = bufferPos(pos);
String result = Utf8.decodeUtf8(buffer, bufferPos, size);
pos += size;
return result;
} else {
// TODO(nathanmittler): Is there a way to avoid this copy?
// The same as readBytes' logic
byte[] bytes = new byte[size];
UnsafeUtil.copyMemory(pos, bytes, 0, size);
// TODO(martinrb): We could save a pass by validating while decoding.
if (!Utf8.isValidUtf8(bytes)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
String result = new String(bytes, UTF_8);
pos += size;
return result;
}
final int bufferPos = bufferPos(pos);
String result = Utf8.decodeUtf8(buffer, bufferPos, size);
pos += size;
return result;
}
if (size == 0) {
@ -2345,15 +2288,7 @@ public abstract class CodedInputStream {
bytes = readRawBytesSlowPath(size);
tempPos = 0;
}
if (ENABLE_CUSTOM_UTF8_DECODE) {
return Utf8.decodeUtf8(bytes, tempPos, size);
} else {
// TODO(martinrb): We could save a pass by validating while decoding.
if (!Utf8.isValidUtf8(bytes, tempPos, tempPos + size)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
return new String(bytes, tempPos, size, UTF_8);
}
return Utf8.decodeUtf8(bytes, tempPos, size);
}
@Override
@ -3373,34 +3308,15 @@ public abstract class CodedInputStream {
public String readStringRequireUtf8() throws IOException {
final int size = readRawVarint32();
if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) {
if (ENABLE_CUSTOM_UTF8_DECODE) {
final int bufferPos = (int) (currentByteBufferPos - currentByteBufferStartPos);
String result = Utf8.decodeUtf8(currentByteBuffer, bufferPos, size);
currentByteBufferPos += size;
return result;
} else {
byte[] bytes = new byte[size];
UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size);
if (!Utf8.isValidUtf8(bytes)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
String result = new String(bytes, UTF_8);
currentByteBufferPos += size;
return result;
}
final int bufferPos = (int) (currentByteBufferPos - currentByteBufferStartPos);
String result = Utf8.decodeUtf8(currentByteBuffer, bufferPos, size);
currentByteBufferPos += size;
return result;
}
if (size >= 0 && size <= remaining()) {
byte[] bytes = new byte[size];
readRawBytesTo(bytes, 0, size);
if (ENABLE_CUSTOM_UTF8_DECODE) {
return Utf8.decodeUtf8(bytes, 0, size);
} else {
if (!Utf8.isValidUtf8(bytes)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
String result = new String(bytes, UTF_8);
return result;
}
return Utf8.decodeUtf8(bytes, 0, size);
}
if (size == 0) {

View File

@ -42,7 +42,8 @@ import java.util.RandomAccess;
*
* @author dweis@google.com (Daniel Weis)
*/
final class DoubleArrayList extends AbstractProtobufList<Double>
final class DoubleArrayList
extends AbstractProtobufList<Double>
implements DoubleList, RandomAccess, PrimitiveNonBoxingCollection {
private static final DoubleArrayList EMPTY_LIST = new DoubleArrayList();

View File

@ -608,20 +608,12 @@ public final class DynamicMessage extends AbstractMessage {
@Override
public Builder setUnknownFields(UnknownFieldSet unknownFields) {
if (getDescriptorForType().getFile().getSyntax() == Descriptors.FileDescriptor.Syntax.PROTO3
&& CodedInputStream.getProto3DiscardUnknownFieldsDefault()) {
return this;
}
this.unknownFields = unknownFields;
return this;
}
@Override
public Builder mergeUnknownFields(UnknownFieldSet unknownFields) {
if (getDescriptorForType().getFile().getSyntax() == Descriptors.FileDescriptor.Syntax.PROTO3
&& CodedInputStream.getProto3DiscardUnknownFieldsDefault()) {
return this;
}
this.unknownFields =
UnknownFieldSet.newBuilder(this.unknownFields)
.mergeFrom(unknownFields)

View File

@ -42,7 +42,8 @@ import java.util.RandomAccess;
*
* @author dweis@google.com (Daniel Weis)
*/
final class FloatArrayList extends AbstractProtobufList<Float>
final class FloatArrayList
extends AbstractProtobufList<Float>
implements FloatList, RandomAccess, PrimitiveNonBoxingCollection {
private static final FloatArrayList EMPTY_LIST = new FloatArrayList();

View File

@ -33,7 +33,6 @@ package com.google.protobuf;
import com.google.protobuf.AbstractMessageLite.Builder.LimitedInputStream;
import com.google.protobuf.Internal.BooleanList;
import com.google.protobuf.Internal.DoubleList;
import com.google.protobuf.Internal.EnumLiteMap;
import com.google.protobuf.Internal.FloatList;
import com.google.protobuf.Internal.IntList;
import com.google.protobuf.Internal.LongList;
@ -1600,7 +1599,7 @@ public abstract class GeneratedMessageLite<
protected static class DefaultInstanceBasedParser<T extends GeneratedMessageLite<T, ?>>
extends AbstractParser<T> {
private T defaultInstance;
private final T defaultInstance;
public DefaultInstanceBasedParser(T defaultInstance) {
this.defaultInstance = defaultInstance;

View File

@ -38,6 +38,11 @@ import com.google.protobuf.Descriptors.EnumValueDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Descriptors.OneofDescriptor;
import com.google.protobuf.Internal.BooleanList;
import com.google.protobuf.Internal.DoubleList;
import com.google.protobuf.Internal.FloatList;
import com.google.protobuf.Internal.IntList;
import com.google.protobuf.Internal.LongList;
// In opensource protobuf, we have versioned this GeneratedMessageV3 class to GeneratedMessageV3V3 and
// in the future may have GeneratedMessageV3V4 etc. This allows us to change some aspects of this
// class without breaking binary compatibility with old generated code that still subclasses
@ -293,16 +298,17 @@ public abstract class GeneratedMessageV3 extends AbstractMessage
return unknownFields.mergeFieldFrom(tag, input);
}
/**
* Delegates to parseUnknownField. This method is obsolete, but we must retain it for
* compatibility with older generated code.
*/
protected boolean parseUnknownFieldProto3(
CodedInputStream input,
UnknownFieldSet.Builder unknownFields,
ExtensionRegistryLite extensionRegistry,
int tag)
throws IOException {
if (input.shouldDiscardUnknownFieldsProto3()) {
return input.skipField(tag);
}
return unknownFields.mergeFieldFrom(tag, input);
return parseUnknownField(input, unknownFields, extensionRegistry, tag);
}
protected static <M extends Message> M parseWithIOException(Parser<M> parser, InputStream input)
@ -363,6 +369,76 @@ public abstract class GeneratedMessageV3 extends AbstractMessage
return UnsafeUtil.hasUnsafeArrayOperations() && UnsafeUtil.hasUnsafeByteBufferOperations();
}
protected static IntList emptyIntList() {
return IntArrayList.emptyList();
}
protected static IntList newIntList() {
return new IntArrayList();
}
protected static IntList mutableCopy(IntList list) {
int size = list.size();
return list.mutableCopyWithCapacity(
size == 0 ? AbstractProtobufList.DEFAULT_CAPACITY : size * 2);
}
protected static LongList emptyLongList() {
return LongArrayList.emptyList();
}
protected static LongList newLongList() {
return new LongArrayList();
}
protected static LongList mutableCopy(LongList list) {
int size = list.size();
return list.mutableCopyWithCapacity(
size == 0 ? AbstractProtobufList.DEFAULT_CAPACITY : size * 2);
}
protected static FloatList emptyFloatList() {
return FloatArrayList.emptyList();
}
protected static FloatList newFloatList() {
return new FloatArrayList();
}
protected static FloatList mutableCopy(FloatList list) {
int size = list.size();
return list.mutableCopyWithCapacity(
size == 0 ? AbstractProtobufList.DEFAULT_CAPACITY : size * 2);
}
protected static DoubleList emptyDoubleList() {
return DoubleArrayList.emptyList();
}
protected static DoubleList newDoubleList() {
return new DoubleArrayList();
}
protected static DoubleList mutableCopy(DoubleList list) {
int size = list.size();
return list.mutableCopyWithCapacity(
size == 0 ? AbstractProtobufList.DEFAULT_CAPACITY : size * 2);
}
protected static BooleanList emptyBooleanList() {
return BooleanArrayList.emptyList();
}
protected static BooleanList newBooleanList() {
return new BooleanArrayList();
}
protected static BooleanList mutableCopy(BooleanList list) {
int size = list.size();
return list.mutableCopyWithCapacity(
size == 0 ? AbstractProtobufList.DEFAULT_CAPACITY : size * 2);
}
@Override
public void writeTo(final CodedOutputStream output) throws IOException {
MessageReflection.writeMessageTo(this, getAllFieldsRaw(), output, false);
@ -641,13 +717,12 @@ public abstract class GeneratedMessageV3 extends AbstractMessage
return (BuilderType) this;
}
/**
* Delegates to setUnknownFields. This method is obsolete, but we must retain it for
* compatibility with older generated code.
*/
protected BuilderType setUnknownFieldsProto3(final UnknownFieldSet unknownFields) {
if (CodedInputStream.getProto3DiscardUnknownFieldsDefault()) {
return (BuilderType) this;
}
this.unknownFields = unknownFields;
onChanged();
return (BuilderType) this;
return setUnknownFields(unknownFields);
}
@Override
@ -1009,19 +1084,17 @@ public abstract class GeneratedMessageV3 extends AbstractMessage
getDescriptorForType(), new MessageReflection.ExtensionAdapter(extensions), tag);
}
/**
* Delegates to parseUnknownField. This method is obsolete, but we must retain it for
* compatibility with older generated code.
*/
@Override
protected boolean parseUnknownFieldProto3(
CodedInputStream input,
UnknownFieldSet.Builder unknownFields,
ExtensionRegistryLite extensionRegistry,
int tag) throws IOException {
return MessageReflection.mergeFieldFrom(
input,
input.shouldDiscardUnknownFieldsProto3() ? null : unknownFields,
extensionRegistry,
getDescriptorForType(),
new MessageReflection.ExtensionAdapter(extensions),
tag);
return parseUnknownField(input, unknownFields, extensionRegistry, tag);
}

View File

@ -42,7 +42,8 @@ import java.util.RandomAccess;
*
* @author dweis@google.com (Daniel Weis)
*/
final class IntArrayList extends AbstractProtobufList<Integer>
final class IntArrayList
extends AbstractProtobufList<Integer>
implements IntList, RandomAccess, PrimitiveNonBoxingCollection {
private static final IntArrayList EMPTY_LIST = new IntArrayList();

View File

@ -234,6 +234,11 @@ public final class Internal {
T findValueByNumber(int number);
}
/** Interface for an object which verifies integers are in range. */
public interface EnumVerifier {
boolean isInRange(int number);
}
/**
* Helper method for implementing {@link Message#hashCode()} for longs.
* @see Long#hashCode()

View File

@ -69,7 +69,7 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
static {
EMPTY_LIST.makeImmutable();
}
static LazyStringArrayList emptyList() {
return EMPTY_LIST;
}
@ -83,8 +83,8 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
this(DEFAULT_CAPACITY);
}
public LazyStringArrayList(int intialCapacity) {
this(new ArrayList<Object>(intialCapacity));
public LazyStringArrayList(int initialCapacity) {
this(new ArrayList<Object>(initialCapacity));
}
public LazyStringArrayList(LazyStringList from) {
@ -95,7 +95,7 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
public LazyStringArrayList(List<String> from) {
this(new ArrayList<Object>(from));
}
private LazyStringArrayList(ArrayList<Object> list) {
this.list = list;
}
@ -150,13 +150,13 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
list.add(index, element);
modCount++;
}
private void add(int index, ByteString element) {
ensureIsMutable();
list.add(index, element);
modCount++;
}
private void add(int index, byte[] element) {
ensureIsMutable();
list.add(index, element);
@ -221,7 +221,7 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
list.add(element);
modCount++;
}
@Override
public void add(byte[] element) {
ensureIsMutable();
@ -233,7 +233,7 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
public Object getRaw(int index) {
return list.get(index);
}
@Override
public ByteString getByteString(int index) {
Object o = list.get(index);
@ -243,7 +243,7 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
}
return b;
}
@Override
public byte[] getByteArray(int index) {
Object o = list.get(index);
@ -258,7 +258,7 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
public void set(int index, ByteString s) {
setAndReturn(index, s);
}
private Object setAndReturn(int index, ByteString s) {
ensureIsMutable();
return list.set(index, s);
@ -268,7 +268,7 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
public void set(int index, byte[] s) {
setAndReturn(index, s);
}
private Object setAndReturn(int index, byte[] s) {
ensureIsMutable();
return list.set(index, s);
@ -283,7 +283,7 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
return Internal.toStringUtf8((byte[]) o);
}
}
private static ByteString asByteString(Object o) {
if (o instanceof ByteString) {
return (ByteString) o;
@ -293,7 +293,7 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
return ByteString.copyFrom((byte[]) o);
}
}
private static byte[] asByteArray(Object o) {
if (o instanceof byte[]) {
return (byte[]) o;
@ -327,11 +327,11 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
private static class ByteArrayListView extends AbstractList<byte[]>
implements RandomAccess {
private final LazyStringArrayList list;
ByteArrayListView(LazyStringArrayList list) {
this.list = list;
}
@Override
public byte[] get(int index) {
return list.getByteArray(index);
@ -362,7 +362,7 @@ public class LazyStringArrayList extends AbstractProtobufList<String>
return asByteArray(o);
}
}
@Override
public List<byte[]> asByteArrayList() {
return new ByteArrayListView(this);

View File

@ -42,7 +42,8 @@ import java.util.RandomAccess;
*
* @author dweis@google.com (Daniel Weis)
*/
final class LongArrayList extends AbstractProtobufList<Long>
final class LongArrayList
extends AbstractProtobufList<Long>
implements LongList, RandomAccess, PrimitiveNonBoxingCollection {
private static final LongArrayList EMPTY_LIST = new LongArrayList();

View File

@ -30,6 +30,8 @@
package com.google.protobuf;
import static com.google.protobuf.Internal.checkNotNull;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collection;
@ -290,9 +292,7 @@ public class RepeatedFieldBuilder
*/
public RepeatedFieldBuilder<MType, BType, IType> setMessage(
int index, MType message) {
if (message == null) {
throw new NullPointerException();
}
checkNotNull(message);
ensureMutableMessageList();
messages.set(index, message);
if (builders != null) {
@ -315,9 +315,7 @@ public class RepeatedFieldBuilder
*/
public RepeatedFieldBuilder<MType, BType, IType> addMessage(
MType message) {
if (message == null) {
throw new NullPointerException();
}
checkNotNull(message);
ensureMutableMessageList();
messages.add(message);
if (builders != null) {
@ -339,9 +337,7 @@ public class RepeatedFieldBuilder
*/
public RepeatedFieldBuilder<MType, BType, IType> addMessage(
int index, MType message) {
if (message == null) {
throw new NullPointerException();
}
checkNotNull(message);
ensureMutableMessageList();
messages.add(index, message);
if (builders != null) {
@ -363,9 +359,7 @@ public class RepeatedFieldBuilder
public RepeatedFieldBuilder<MType, BType, IType> addAllMessages(
Iterable<? extends MType> values) {
for (final MType value : values) {
if (value == null) {
throw new NullPointerException();
}
checkNotNull(value);
}
// If we can inspect the size, we can more efficiently add messages.

View File

@ -30,6 +30,8 @@
package com.google.protobuf;
import static com.google.protobuf.Internal.checkNotNull;
/**
* {@code SingleFieldBuilder} implements a structure that a protocol
* message uses to hold a single field of another protocol message. It supports
@ -84,10 +86,7 @@ public class SingleFieldBuilder
MType message,
GeneratedMessage.BuilderParent parent,
boolean isClean) {
if (message == null) {
throw new NullPointerException();
}
this.message = message;
this.message = checkNotNull(message);
this.parent = parent;
this.isClean = isClean;
}
@ -169,10 +168,7 @@ public class SingleFieldBuilder
*/
public SingleFieldBuilder<MType, BType, IType> setMessage(
MType message) {
if (message == null) {
throw new NullPointerException();
}
this.message = message;
this.message = checkNotNull(message);
if (builder != null) {
builder.dispose();
builder = null;

View File

@ -1444,8 +1444,8 @@ public final class TextFormat {
logger.warning(msg.toString());
} else {
String[] lineColumn = unknownFields.get(0).split(":");
throw new ParseException(Integer.valueOf(lineColumn[0]),
Integer.valueOf(lineColumn[1]), msg.toString());
throw new ParseException(
Integer.parseInt(lineColumn[0]), Integer.parseInt(lineColumn[1]), msg.toString());
}
}

View File

@ -785,6 +785,23 @@ public final class UnknownFieldSet implements MessageLite {
group};
}
/**
* Serializes the message to a {@code ByteString} and returns it. This is just a trivial wrapper
* around {@link #writeTo(int, CodedOutputStream)}.
*/
public ByteString toByteString(int fieldNumber) {
try {
// TODO(lukes): consider caching serialized size in a volatile long
final ByteString.CodedBuilder out =
ByteString.newCodedBuilder(getSerializedSize(fieldNumber));
writeTo(fieldNumber, out.getCodedOutput());
return out.build();
} catch (IOException e) {
throw new RuntimeException(
"Serializing to a ByteString should never fail with an IOException", e);
}
}
/**
* Serializes the field, including field number, and writes it to
* {@code output}.

View File

@ -33,6 +33,7 @@ package com.google.protobuf;
import java.lang.reflect.Field;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.logging.Level;
@ -146,6 +147,10 @@ final class UnsafeUtil {
return MEMORY_ACCESSOR.getObject(target, offset);
}
static void putObject(Object target, long offset, Object value) {
MEMORY_ACCESSOR.putObject(target, offset, value);
}
static byte getByte(byte[] target, long index) {
return MEMORY_ACCESSOR.getByte(target, BYTE_ARRAY_BASE_OFFSET + index);
}
@ -370,12 +375,6 @@ final class UnsafeUtil {
return field != null && field.getType() == long.class ? field : null;
}
/** Finds the value field within a {@link String}. */
private static Field stringValueField() {
Field field = field(String.class, "value");
return field != null && field.getType() == char[].class ? field : null;
}
/**
* Returns the offset of the provided field, or {@code -1} if {@code sun.misc.Unsafe} is not
* available.

View File

@ -42,7 +42,6 @@ import static java.lang.Character.isSurrogatePair;
import static java.lang.Character.toCodePoint;
import java.nio.ByteBuffer;
import java.util.Arrays;
/**
* A set of low-level, high-performance static utility methods related
@ -87,7 +86,9 @@ final class Utf8 {
* delegate for which all methods are delegated directly to.
*/
private static final Processor processor =
UnsafeProcessor.isAvailable() ? new UnsafeProcessor() : new SafeProcessor();
(UnsafeProcessor.isAvailable() && !Android.isOnAndroidDevice())
? new UnsafeProcessor()
: new SafeProcessor();
/**
* A mask used when performing unsafe reads to determine if a long value contains any non-ASCII

View File

@ -210,8 +210,8 @@ public class AbstractMessageTest extends TestCase {
new TestUtil.ReflectionTester(TestAllTypes.getDescriptor(), null);
TestUtil.ReflectionTester extensionsReflectionTester =
new TestUtil.ReflectionTester(TestAllExtensions.getDescriptor(),
TestUtil.getExtensionRegistry());
new TestUtil.ReflectionTester(
TestAllExtensions.getDescriptor(), TestUtil.getFullExtensionRegistry());
public void testClear() throws Exception {
AbstractMessageWrapper message =

View File

@ -299,20 +299,22 @@ public class BooleanArrayListTest extends TestCase {
}
public void testRemoveEndOfCapacity() {
BooleanList toRemove = BooleanArrayList.emptyList().mutableCopyWithCapacity(1);
BooleanList toRemove =
BooleanArrayList.emptyList().mutableCopyWithCapacity(1);
toRemove.addBoolean(true);
toRemove.remove(0);
assertEquals(0, toRemove.size());
}
public void testSublistRemoveEndOfCapacity() {
BooleanList toRemove = BooleanArrayList.emptyList().mutableCopyWithCapacity(1);
BooleanList toRemove =
BooleanArrayList.emptyList().mutableCopyWithCapacity(1);
toRemove.addBoolean(true);
toRemove.subList(0, 1).clear();
assertEquals(0, toRemove.size());
}
private void assertImmutable(BooleanArrayList list) {
private void assertImmutable(BooleanList list) {
try {
list.add(true);

View File

@ -41,6 +41,7 @@ import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
@ -86,6 +87,40 @@ public class ByteStringTest extends TestCase {
return left.length == right.length && isArrayRange(left, right, 0, left.length);
}
public void testCompare_equalByteStrings_compareEqual() throws Exception {
byte[] referenceBytes = getTestBytes();
ByteString string1 = ByteString.copyFrom(referenceBytes);
ByteString string2 = ByteString.copyFrom(referenceBytes);
assertEquals(
"ByteString instances containing the same data must compare equal.",
0,
ByteString.unsignedLexicographicalComparator().compare(string1, string2));
}
public void testCompare_byteStringsSortLexicographically() throws Exception {
ByteString app = ByteString.copyFromUtf8("app");
ByteString apple = ByteString.copyFromUtf8("apple");
ByteString banana = ByteString.copyFromUtf8("banana");
Comparator<ByteString> comparator = ByteString.unsignedLexicographicalComparator();
assertTrue("ByteString(app) < ByteString(apple)", comparator.compare(app, apple) < 0);
assertTrue("ByteString(app) < ByteString(banana)", comparator.compare(app, banana) < 0);
assertTrue("ByteString(apple) < ByteString(banana)", comparator.compare(apple, banana) < 0);
}
public void testCompare_interpretsByteValuesAsUnsigned() throws Exception {
// Two's compliment of `-1` == 0b11111111 == 255
ByteString twoHundredFiftyFive = ByteString.copyFrom(new byte[] {-1});
// 0b00000001 == 1
ByteString one = ByteString.copyFrom(new byte[] {1});
assertTrue(
"ByteString comparison treats bytes as unsigned values",
ByteString.unsignedLexicographicalComparator().compare(one, twoHundredFiftyFive) < 0);
}
public void testSubstring_BeginIndex() {
byte[] bytes = getTestBytes();
ByteString substring = ByteString.copyFrom(bytes).substring(500);
@ -161,6 +196,34 @@ public class ByteStringTest extends TestCase {
byteString, byteStringAlt);
}
public void testCopyFrom_LengthTooBig() {
byte[] testBytes = getTestBytes(100);
try {
ByteString.copyFrom(testBytes, 0, 200);
fail("Should throw");
} catch (IndexOutOfBoundsException expected) {
}
try {
ByteString.copyFrom(testBytes, 99, 2);
fail();
} catch (IndexOutOfBoundsException expected) {
}
ByteBuffer buf = ByteBuffer.wrap(testBytes);
try {
ByteString.copyFrom(buf, 101);
fail();
} catch (IndexOutOfBoundsException expected) {
}
try {
ByteString.copyFrom(testBytes, -1, 10);
fail("Should throw");
} catch (IndexOutOfBoundsException expected) {
}
}
public void testCopyTo_TargetOffset() {
byte[] bytes = getTestBytes();
ByteString byteString = ByteString.copyFrom(bytes);
@ -761,6 +824,9 @@ public class ByteStringTest extends TestCase {
* Tests ByteString uses Arrays based byte copier when running under Hotstop VM.
*/
public void testByteArrayCopier() throws Exception {
if (Android.isOnAndroidDevice()) {
return;
}
Field field = ByteString.class.getDeclaredField("byteArrayCopier");
field.setAccessible(true);
Object byteArrayCopier = field.get(null);

View File

@ -62,22 +62,17 @@ public class DiscardUnknownFieldsTest {
}
private static void testProto2Message(Message message) throws Exception {
assertUnknownFieldsDefaultPreserved(message);
assertUnknownFieldsPreserved(message);
assertUnknownFieldsExplicitlyDiscarded(message);
assertReuseCodedInputStreamPreserve(message);
assertUnknownFieldsInUnknownFieldSetArePreserve(message);
}
private static void testProto3Message(Message message) throws Exception {
CodedInputStream.setProto3KeepUnknownsByDefaultForTest();
assertUnknownFieldsDefaultPreserved(message);
assertUnknownFieldsPreserved(message);
assertUnknownFieldsExplicitlyDiscarded(message);
assertReuseCodedInputStreamPreserve(message);
assertUnknownFieldsInUnknownFieldSetArePreserve(message);
CodedInputStream.setProto3DiscardUnknownsByDefaultForTest();
assertUnknownFieldsDefaultDiscarded(message);
assertUnknownFieldsExplicitlyDiscarded(message);
assertUnknownFieldsInUnknownFieldSetAreDiscarded(message);
}
private static void assertReuseCodedInputStreamPreserve(Message message) throws Exception {
@ -122,7 +117,7 @@ public class DiscardUnknownFieldsTest {
assertEquals(message.getClass().getName(), 0, built.getSerializedSize());
}
private static void assertUnknownFieldsDefaultPreserved(MessageLite message) throws Exception {
private static void assertUnknownFieldsPreserved(MessageLite message) throws Exception {
{
MessageLite parsed = message.getParserForType().parseFrom(payload);
assertEquals(message.getClass().getName(), payload, parsed.toByteString());
@ -134,18 +129,6 @@ public class DiscardUnknownFieldsTest {
}
}
private static void assertUnknownFieldsDefaultDiscarded(MessageLite message) throws Exception {
{
MessageLite parsed = message.getParserForType().parseFrom(payload);
assertEquals(message.getClass().getName(), 0, parsed.getSerializedSize());
}
{
MessageLite parsed = message.newBuilderForType().mergeFrom(payload).build();
assertEquals(message.getClass().getName(), 0, parsed.getSerializedSize());
}
}
private static void assertUnknownFieldsExplicitlyDiscarded(Message message) throws Exception {
Message parsed =
DiscardUnknownFieldsParser.wrap(message.getParserForType()).parseFrom(payload);

View File

@ -78,10 +78,10 @@ public class DoubleArrayListTest extends TestCase {
list.addAll(asList(1D, 2D, 3D, 4D));
Iterator<Double> iterator = list.iterator();
assertEquals(4, list.size());
assertEquals(1D, (double) list.get(0));
assertEquals(1D, (double) iterator.next());
assertEquals(1D, (double) list.get(0), 0.0);
assertEquals(1D, (double) iterator.next(), 0.0);
list.set(0, 1D);
assertEquals(2D, (double) iterator.next());
assertEquals(2D, (double) iterator.next(), 0.0);
list.remove(0);
try {
@ -102,9 +102,9 @@ public class DoubleArrayListTest extends TestCase {
}
public void testGet() {
assertEquals(1D, (double) TERTIARY_LIST.get(0));
assertEquals(2D, (double) TERTIARY_LIST.get(1));
assertEquals(3D, (double) TERTIARY_LIST.get(2));
assertEquals(1D, (double) TERTIARY_LIST.get(0), 0.0);
assertEquals(2D, (double) TERTIARY_LIST.get(1), 0.0);
assertEquals(3D, (double) TERTIARY_LIST.get(2), 0.0);
try {
TERTIARY_LIST.get(-1);
@ -122,9 +122,9 @@ public class DoubleArrayListTest extends TestCase {
}
public void testGetDouble() {
assertEquals(1D, TERTIARY_LIST.getDouble(0));
assertEquals(2D, TERTIARY_LIST.getDouble(1));
assertEquals(3D, TERTIARY_LIST.getDouble(2));
assertEquals(1D, TERTIARY_LIST.getDouble(0), 0.0);
assertEquals(2D, TERTIARY_LIST.getDouble(1), 0.0);
assertEquals(3D, TERTIARY_LIST.getDouble(2), 0.0);
try {
TERTIARY_LIST.get(-1);
@ -163,11 +163,11 @@ public class DoubleArrayListTest extends TestCase {
list.addDouble(2);
list.addDouble(4);
assertEquals(2D, (double) list.set(0, 3D));
assertEquals(3D, list.getDouble(0));
assertEquals(2D, (double) list.set(0, 3D), 0.0);
assertEquals(3D, list.getDouble(0), 0.0);
assertEquals(4D, (double) list.set(1, 0D));
assertEquals(0D, list.getDouble(1));
assertEquals(4D, (double) list.set(1, 0D), 0.0);
assertEquals(0D, list.getDouble(1), 0.0);
try {
list.set(-1, 0D);
@ -188,11 +188,11 @@ public class DoubleArrayListTest extends TestCase {
list.addDouble(1);
list.addDouble(3);
assertEquals(1D, list.setDouble(0, 0));
assertEquals(0D, list.getDouble(0));
assertEquals(1D, list.setDouble(0, 0), 0.0);
assertEquals(0D, list.getDouble(0), 0.0);
assertEquals(3D, list.setDouble(1, 0));
assertEquals(0D, list.getDouble(1));
assertEquals(3D, list.setDouble(1, 0), 0.0);
assertEquals(0D, list.getDouble(1), 0.0);
try {
list.setDouble(-1, 0);
@ -257,8 +257,8 @@ public class DoubleArrayListTest extends TestCase {
assertTrue(list.addAll(Collections.singleton(1D)));
assertEquals(1, list.size());
assertEquals(1D, (double) list.get(0));
assertEquals(1D, list.getDouble(0));
assertEquals(1D, (double) list.get(0), 0.0);
assertEquals(1D, list.getDouble(0), 0.0);
assertTrue(list.addAll(asList(2D, 3D, 4D, 5D, 6D)));
assertEquals(asList(1D, 2D, 3D, 4D, 5D, 6D), list);
@ -272,7 +272,7 @@ public class DoubleArrayListTest extends TestCase {
public void testRemove() {
list.addAll(TERTIARY_LIST);
assertEquals(1D, (double) list.remove(0));
assertEquals(1D, (double) list.remove(0), 0.0);
assertEquals(asList(2D, 3D), list);
assertTrue(list.remove(Double.valueOf(3)));
@ -281,7 +281,7 @@ public class DoubleArrayListTest extends TestCase {
assertFalse(list.remove(Double.valueOf(3)));
assertEquals(asList(2D), list);
assertEquals(2D, (double) list.remove(0));
assertEquals(2D, (double) list.remove(0), 0.0);
assertEquals(asList(), list);
try {
@ -299,20 +299,22 @@ public class DoubleArrayListTest extends TestCase {
}
public void testRemoveEndOfCapacity() {
DoubleList toRemove = DoubleArrayList.emptyList().mutableCopyWithCapacity(1);
DoubleList toRemove =
DoubleArrayList.emptyList().mutableCopyWithCapacity(1);
toRemove.addDouble(3);
toRemove.remove(0);
assertEquals(0, toRemove.size());
}
public void testSublistRemoveEndOfCapacity() {
DoubleList toRemove = DoubleArrayList.emptyList().mutableCopyWithCapacity(1);
DoubleList toRemove =
DoubleArrayList.emptyList().mutableCopyWithCapacity(1);
toRemove.addDouble(3);
toRemove.subList(0, 1).clear();
assertEquals(0, toRemove.size());
}
private void assertImmutable(DoubleArrayList list) {
private void assertImmutable(DoubleList list) {
if (list.contains(1D)) {
throw new RuntimeException("Cannot test the immutability of lists that contain 1.");
}

View File

@ -51,8 +51,8 @@ public class DynamicMessageTest extends TestCase {
new TestUtil.ReflectionTester(TestAllTypes.getDescriptor(), null);
TestUtil.ReflectionTester extensionsReflectionTester =
new TestUtil.ReflectionTester(TestAllExtensions.getDescriptor(),
TestUtil.getExtensionRegistry());
new TestUtil.ReflectionTester(
TestAllExtensions.getDescriptor(), TestUtil.getFullExtensionRegistry());
TestUtil.ReflectionTester packedReflectionTester =
new TestUtil.ReflectionTester(TestPackedTypes.getDescriptor(), null);
@ -194,9 +194,9 @@ public class DynamicMessageTest extends TestCase {
public void testDynamicMessageExtensionParsing() throws Exception {
ByteString rawBytes = TestUtil.getAllExtensionsSet().toByteString();
Message message = DynamicMessage.parseFrom(
TestAllExtensions.getDescriptor(), rawBytes,
TestUtil.getExtensionRegistry());
Message message =
DynamicMessage.parseFrom(
TestAllExtensions.getDescriptor(), rawBytes, TestUtil.getFullExtensionRegistry());
extensionsReflectionTester.assertAllFieldsSetViaReflection(message);
// Test Parser interface.

View File

@ -78,10 +78,10 @@ public class FloatArrayListTest extends TestCase {
list.addAll(asList(1F, 2F, 3F, 4F));
Iterator<Float> iterator = list.iterator();
assertEquals(4, list.size());
assertEquals(1F, (float) list.get(0));
assertEquals(1F, (float) iterator.next());
assertEquals(1F, (float) list.get(0), 0.0f);
assertEquals(1F, (float) iterator.next(), 0.0f);
list.set(0, 1F);
assertEquals(2F, (float) iterator.next());
assertEquals(2F, (float) iterator.next(), 0.0f);
list.remove(0);
try {
@ -102,9 +102,9 @@ public class FloatArrayListTest extends TestCase {
}
public void testGet() {
assertEquals(1F, (float) TERTIARY_LIST.get(0));
assertEquals(2F, (float) TERTIARY_LIST.get(1));
assertEquals(3F, (float) TERTIARY_LIST.get(2));
assertEquals(1F, (float) TERTIARY_LIST.get(0), 0.0f);
assertEquals(2F, (float) TERTIARY_LIST.get(1), 0.0f);
assertEquals(3F, (float) TERTIARY_LIST.get(2), 0.0f);
try {
TERTIARY_LIST.get(-1);
@ -122,9 +122,9 @@ public class FloatArrayListTest extends TestCase {
}
public void testGetFloat() {
assertEquals(1F, TERTIARY_LIST.getFloat(0));
assertEquals(2F, TERTIARY_LIST.getFloat(1));
assertEquals(3F, TERTIARY_LIST.getFloat(2));
assertEquals(1F, TERTIARY_LIST.getFloat(0), 0.0f);
assertEquals(2F, TERTIARY_LIST.getFloat(1), 0.0f);
assertEquals(3F, TERTIARY_LIST.getFloat(2), 0.0f);
try {
TERTIARY_LIST.get(-1);
@ -163,11 +163,11 @@ public class FloatArrayListTest extends TestCase {
list.addFloat(2);
list.addFloat(4);
assertEquals(2F, (float) list.set(0, 3F));
assertEquals(3F, list.getFloat(0));
assertEquals(2F, (float) list.set(0, 3F), 0.0f);
assertEquals(3F, list.getFloat(0), 0.0f);
assertEquals(4F, (float) list.set(1, 0F));
assertEquals(0F, list.getFloat(1));
assertEquals(4F, (float) list.set(1, 0F), 0.0f);
assertEquals(0F, list.getFloat(1), 0.0f);
try {
list.set(-1, 0F);
@ -188,11 +188,11 @@ public class FloatArrayListTest extends TestCase {
list.addFloat(1);
list.addFloat(3);
assertEquals(1F, list.setFloat(0, 0));
assertEquals(0F, list.getFloat(0));
assertEquals(1F, list.setFloat(0, 0), 0.0f);
assertEquals(0F, list.getFloat(0), 0.0f);
assertEquals(3F, list.setFloat(1, 0));
assertEquals(0F, list.getFloat(1));
assertEquals(3F, list.setFloat(1, 0), 0.0f);
assertEquals(0F, list.getFloat(1), 0.0f);
try {
list.setFloat(-1, 0);
@ -257,8 +257,8 @@ public class FloatArrayListTest extends TestCase {
assertTrue(list.addAll(Collections.singleton(1F)));
assertEquals(1, list.size());
assertEquals(1F, (float) list.get(0));
assertEquals(1F, list.getFloat(0));
assertEquals(1F, (float) list.get(0), 0.0f);
assertEquals(1F, list.getFloat(0), 0.0f);
assertTrue(list.addAll(asList(2F, 3F, 4F, 5F, 6F)));
assertEquals(asList(1F, 2F, 3F, 4F, 5F, 6F), list);
@ -272,7 +272,7 @@ public class FloatArrayListTest extends TestCase {
public void testRemove() {
list.addAll(TERTIARY_LIST);
assertEquals(1F, (float) list.remove(0));
assertEquals(1F, (float) list.remove(0), 0.0f);
assertEquals(asList(2F, 3F), list);
assertTrue(list.remove(Float.valueOf(3)));
@ -281,7 +281,7 @@ public class FloatArrayListTest extends TestCase {
assertFalse(list.remove(Float.valueOf(3)));
assertEquals(asList(2F), list);
assertEquals(2F, (float) list.remove(0));
assertEquals(2F, (float) list.remove(0), 0.0f);
assertEquals(asList(), list);
try {
@ -299,20 +299,22 @@ public class FloatArrayListTest extends TestCase {
}
public void testRemoveEndOfCapacity() {
FloatList toRemove = FloatArrayList.emptyList().mutableCopyWithCapacity(1);
FloatList toRemove =
FloatArrayList.emptyList().mutableCopyWithCapacity(1);
toRemove.addFloat(3);
toRemove.remove(0);
assertEquals(0, toRemove.size());
}
public void testSublistRemoveEndOfCapacity() {
FloatList toRemove = FloatArrayList.emptyList().mutableCopyWithCapacity(1);
FloatList toRemove =
FloatArrayList.emptyList().mutableCopyWithCapacity(1);
toRemove.addFloat(3);
toRemove.subList(0, 1).clear();
assertEquals(0, toRemove.size());
}
private void assertImmutable(FloatArrayList list) {
private void assertImmutable(FloatList list) {
if (list.contains(1F)) {
throw new RuntimeException("Cannot test the immutability of lists that contain 1.");
}

View File

@ -590,8 +590,8 @@ public class GeneratedMessageTest extends TestCase {
// Extensions.
TestUtil.ReflectionTester extensionsReflectionTester =
new TestUtil.ReflectionTester(TestAllExtensions.getDescriptor(),
TestUtil.getExtensionRegistry());
new TestUtil.ReflectionTester(
TestAllExtensions.getDescriptor(), TestUtil.getFullExtensionRegistry());
public void testExtensionMessageOrBuilder() throws Exception {
TestAllExtensions.Builder builder = TestAllExtensions.newBuilder();

View File

@ -299,20 +299,22 @@ public class IntArrayListTest extends TestCase {
}
public void testRemoveEndOfCapacity() {
IntList toRemove = IntArrayList.emptyList().mutableCopyWithCapacity(1);
IntList toRemove =
IntArrayList.emptyList().mutableCopyWithCapacity(1);
toRemove.addInt(3);
toRemove.remove(0);
assertEquals(0, toRemove.size());
}
public void testSublistRemoveEndOfCapacity() {
IntList toRemove = IntArrayList.emptyList().mutableCopyWithCapacity(1);
IntList toRemove =
IntArrayList.emptyList().mutableCopyWithCapacity(1);
toRemove.addInt(3);
toRemove.subList(0, 1).clear();
assertEquals(0, toRemove.size());
}
private void assertImmutable(IntArrayList list) {
private void assertImmutable(IntList list) {
if (list.contains(1)) {
throw new RuntimeException("Cannot test the immutability of lists that contain 1.");
}

View File

@ -57,6 +57,8 @@ import protobuf_unittest.lite_equals_and_hash.LiteEqualsAndHash.TestOneofEquals;
import protobuf_unittest.lite_equals_and_hash.LiteEqualsAndHash.TestRecursiveOneof;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.util.ArrayList;
@ -2378,4 +2380,63 @@ public class LiteTest extends TestCase {
} catch (NullPointerException expected) {
}
}
public void testSerializeToOutputStreamThrowsIOException() {
try {
TestAllTypesLite.newBuilder()
.setOptionalBytes(ByteString.copyFromUtf8("hello"))
.build()
.writeTo(
new OutputStream() {
@Override
public void write(int b) throws IOException {
throw new IOException();
}
});
fail();
} catch (IOException expected) {
}
}
public void testUnpairedSurrogatesReplacedByQuestionMark() throws InvalidProtocolBufferException {
String testString = "foo \ud83d bar";
String expectedString = "foo ? bar";
TestAllTypesLite testMessage =
TestAllTypesLite.newBuilder().setOptionalString(testString).build();
ByteString serializedMessage = testMessage.toByteString();
// Behavior is compatible with String.getBytes("UTF-8"), which replaces
// unpaired surrogates with a question mark.
TestAllTypesLite parsedMessage = TestAllTypesLite.parseFrom(serializedMessage);
assertEquals(expectedString, parsedMessage.getOptionalString());
// Conversion happens during serialization.
ByteString expectedBytes = ByteString.copyFromUtf8(expectedString);
assertTrue(
String.format(
"Expected serializedMessage (%s) to contain \"%s\" (%s).",
encodeHex(serializedMessage), expectedString, encodeHex(expectedBytes)),
contains(serializedMessage, expectedBytes));
}
private String encodeHex(ByteString bytes) {
String hexDigits = "0123456789abcdef";
StringBuilder stringBuilder = new StringBuilder(bytes.size() * 2);
for (byte b : bytes) {
stringBuilder.append(hexDigits.charAt((b & 0xf0) >> 4));
stringBuilder.append(hexDigits.charAt(b & 0x0f));
}
return stringBuilder.toString();
}
private boolean contains(ByteString a, ByteString b) {
for (int i = 0; i <= a.size() - b.size(); ++i) {
if (a.substring(i, i + b.size()).equals(b)) {
return true;
}
}
return false;
}
}

View File

@ -299,20 +299,22 @@ public class LongArrayListTest extends TestCase {
}
public void testRemoveEndOfCapacity() {
LongList toRemove = LongArrayList.emptyList().mutableCopyWithCapacity(1);
LongList toRemove =
LongArrayList.emptyList().mutableCopyWithCapacity(1);
toRemove.addLong(3);
toRemove.remove(0);
assertEquals(0, toRemove.size());
}
public void testSublistRemoveEndOfCapacity() {
LongList toRemove = LongArrayList.emptyList().mutableCopyWithCapacity(1);
LongList toRemove =
LongArrayList.emptyList().mutableCopyWithCapacity(1);
toRemove.addLong(3);
toRemove.subList(0, 1).clear();
assertEquals(0, toRemove.size());
}
private void assertImmutable(LongArrayList list) {
private void assertImmutable(LongList list) {
if (list.contains(1L)) {
throw new RuntimeException("Cannot test the immutability of lists that contain 1.");
}

View File

@ -42,33 +42,33 @@ import junit.framework.TestCase;
* Tests for {@link ProtobufArrayList}.
*/
public class ProtobufArrayListTest extends TestCase {
private static final ProtobufArrayList<Integer> UNARY_LIST = newImmutableProtoArrayList(1);
private static final ProtobufArrayList<Integer> TERTIARY_LIST =
newImmutableProtoArrayList(1, 2, 3);
private ProtobufArrayList<Integer> list;
@Override
protected void setUp() throws Exception {
list = new ProtobufArrayList<Integer>();
}
public void testEmptyListReturnsSameInstance() {
assertSame(ProtobufArrayList.emptyList(), ProtobufArrayList.emptyList());
}
public void testEmptyListIsImmutable() {
assertImmutable(ProtobufArrayList.<Integer>emptyList());
}
public void testModificationWithIteration() {
list.addAll(asList(1, 2, 3, 4));
Iterator<Integer> iterator = list.iterator();
assertEquals(4, list.size());
assertEquals(1, (int) list.get(0));
assertEquals(1, (int) iterator.next());
list.remove(0);
try {
iterator.next();
@ -76,7 +76,7 @@ public class ProtobufArrayListTest extends TestCase {
} catch (ConcurrentModificationException e) {
// expected
}
iterator = list.iterator();
list.set(0, 1);
try {
@ -85,7 +85,7 @@ public class ProtobufArrayListTest extends TestCase {
} catch (ConcurrentModificationException e) {
// expected
}
iterator = list.iterator();
list.add(0, 0);
try {
@ -95,7 +95,7 @@ public class ProtobufArrayListTest extends TestCase {
// expected
}
}
public void testMakeImmutable() {
list.add(2);
list.add(4);
@ -104,107 +104,213 @@ public class ProtobufArrayListTest extends TestCase {
list.makeImmutable();
assertImmutable(list);
}
public void testRemove() {
list.add(2);
list.add(4);
list.add(6);
list.addAll(TERTIARY_LIST);
assertEquals(1, (int) list.remove(0));
assertEquals(asList(2, 3), list);
list.remove(1);
assertEquals(asList(2, 6), list);
list.remove(1);
assertTrue(list.remove(Integer.valueOf(3)));
assertEquals(asList(2), list);
list.remove(0);
assertFalse(list.remove(Integer.valueOf(3)));
assertEquals(asList(2), list);
assertEquals(2, (int) list.remove(0));
assertEquals(asList(), list);
try {
list.remove(-1);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
try {
list.remove(0);
} catch (IndexOutOfBoundsException e) {
// expected
}
}
public void testGet() {
list.add(2);
list.add(6);
assertEquals(2, (int) list.get(0));
assertEquals(6, (int) list.get(1));
assertEquals(1, (int) TERTIARY_LIST.get(0));
assertEquals(2, (int) TERTIARY_LIST.get(1));
assertEquals(3, (int) TERTIARY_LIST.get(2));
try {
TERTIARY_LIST.get(-1);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
try {
TERTIARY_LIST.get(3);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
}
public void testSet() {
list.add(2);
list.add(6);
list.set(0, 1);
list.add(4);
assertEquals(2, (int) list.set(0, 3));
assertEquals(3, (int) list.get(0));
assertEquals(4, (int) list.set(1, 0));
assertEquals(0, (int) list.get(1));
try {
list.set(-1, 0);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
try {
list.set(2, 0);
fail();
} catch (IndexOutOfBoundsException e) {
// expected
}
}
public void testAdd() {
assertEquals(0, list.size());
assertTrue(list.add(2));
assertEquals(asList(2), list);
assertTrue(list.add(3));
list.add(0, 4);
assertEquals(asList(4, 2, 3), list);
list.add(0, 1);
list.add(0, 0);
// Force a resize by getting up to 11 elements.
for (int i = 0; i < 6; i++) {
list.add(Integer.valueOf(5 + i));
}
assertEquals(asList(0, 1, 4, 2, 3, 5, 6, 7, 8, 9, 10), list);
try {
list.add(-1, 5);
} catch (IndexOutOfBoundsException e) {
// expected
}
try {
list.add(4, 5);
} catch (IndexOutOfBoundsException e) {
// expected
}
}
public void testAddAll() {
assertEquals(0, list.size());
assertTrue(list.addAll(Collections.singleton(1)));
assertEquals(1, list.size());
assertEquals(1, (int) list.get(0));
list.set(1, 2);
assertEquals(2, (int) list.get(1));
assertTrue(list.addAll(asList(2, 3, 4, 5, 6)));
assertEquals(asList(1, 2, 3, 4, 5, 6), list);
assertTrue(list.addAll(TERTIARY_LIST));
assertEquals(asList(1, 2, 3, 4, 5, 6, 1, 2, 3), list);
assertFalse(list.addAll(Collections.<Integer>emptyList()));
assertFalse(list.addAll(IntArrayList.emptyList()));
}
public void testSize() {
assertEquals(0, ProtobufArrayList.emptyList().size());
assertEquals(1, UNARY_LIST.size());
assertEquals(3, TERTIARY_LIST.size());
list.add(3);
list.add(4);
list.add(6);
list.add(8);
assertEquals(4, list.size());
list.remove(0);
assertEquals(3, list.size());
list.add(17);
assertEquals(4, list.size());
}
private void assertImmutable(List<Integer> list) {
if (list.contains(1)) {
throw new RuntimeException("Cannot test the immutability of lists that contain 1.");
}
try {
list.add(1);
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.add(0, 1);
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.addAll(Collections.<Integer>emptyList());
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.addAll(Collections.singletonList(1));
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.addAll(new ProtobufArrayList<Integer>());
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.addAll(UNARY_LIST);
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.addAll(0, Collections.singleton(1));
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.addAll(0, UNARY_LIST);
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.addAll(0, Collections.<Integer>emptyList());
fail();
} catch (UnsupportedOperationException e) {
// expected
}
}
try {
list.clear();
@ -219,56 +325,56 @@ public class ProtobufArrayListTest extends TestCase {
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.remove(new Object());
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.removeAll(Collections.emptyList());
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.removeAll(Collections.singleton(1));
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.removeAll(UNARY_LIST);
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.retainAll(Collections.emptyList());
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.retainAll(Collections.singleton(1));
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.retainAll(UNARY_LIST);
fail();
} catch (UnsupportedOperationException e) {
// expected
}
try {
list.set(0, 0);
fail();
@ -276,7 +382,7 @@ public class ProtobufArrayListTest extends TestCase {
// expected
}
}
private static ProtobufArrayList<Integer> newImmutableProtoArrayList(int... elements) {
ProtobufArrayList<Integer> list = new ProtobufArrayList<Integer>();
for (int element : elements) {

View File

@ -130,8 +130,6 @@ import static protobuf_unittest.UnittestProto.defaultFixed64Extension;
import static protobuf_unittest.UnittestProto.defaultFloatExtension;
import static protobuf_unittest.UnittestProto.defaultForeignEnumExtension;
import static protobuf_unittest.UnittestProto.defaultImportEnumExtension;
// The static imports are to avoid 100+ char lines. The following is roughly equivalent to
// import static protobuf_unittest.UnittestProto.*;
import static protobuf_unittest.UnittestProto.defaultInt32Extension;
import static protobuf_unittest.UnittestProto.defaultInt64Extension;
import static protobuf_unittest.UnittestProto.defaultNestedEnumExtension;
@ -263,12 +261,14 @@ public final class TestUtil {
return ByteString.copyFrom(str.getBytes(Internal.UTF_8));
}
// BEGIN FULL-RUNTIME
/**
* Dirties the message by resetting the momoized serialized size.
*/
public static void resetMemoizedSize(AbstractMessage message) {
message.memoizedSize = -1;
}
// END FULL-RUNTIME
/**
* Get a {@code TestAllTypes} with all fields set as they would be by
@ -1201,17 +1201,29 @@ public final class TestUtil {
* Get an unmodifiable {@link ExtensionRegistry} containing all the
* extensions of {@code TestAllExtensions}.
*/
public static ExtensionRegistry getExtensionRegistry() {
public static ExtensionRegistryLite getExtensionRegistry() {
ExtensionRegistryLite registry = ExtensionRegistryLite.newInstance();
registerAllExtensions(registry);
return registry.getUnmodifiable();
}
// BEGIN FULL-RUNTIME
/**
* Get an unmodifiable {@link ExtensionRegistry} containing all the
* extensions of {@code TestAllExtensions}.
*/
public static ExtensionRegistry getFullExtensionRegistry() {
ExtensionRegistry registry = ExtensionRegistry.newInstance();
registerAllExtensions(registry);
return registry.getUnmodifiable();
}
// END FULL-RUNTIME
/**
* Register all of {@code TestAllExtensions}'s extensions with the
* given {@link ExtensionRegistry}.
*/
public static void registerAllExtensions(ExtensionRegistry registry) {
public static void registerAllExtensions(ExtensionRegistryLite registry) {
UnittestProto.registerAllExtensions(registry);
TestUtilLite.registerAllExtensionsLite(registry);
}
@ -2634,6 +2646,7 @@ public final class TestUtil {
}
// =================================================================
// BEGIN FULL-RUNTIME
/**
* Performs the same things that the methods of {@code TestUtil} do, but
@ -3819,6 +3832,16 @@ public final class TestUtil {
"Couldn't read file: " + fullPath.getPath(), e);
}
}
// END FULL-RUNTIME
private static ByteString readBytesFromResource(String name) {
try {
return ByteString.copyFrom(
com.google.common.io.ByteStreams.toByteArray(TestUtil.class.getResourceAsStream(name)));
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Get the bytes of the "golden message". This is a serialized TestAllTypes
@ -3829,7 +3852,7 @@ public final class TestUtil {
*/
public static ByteString getGoldenMessage() {
if (goldenMessage == null) {
goldenMessage = readBytesFromFile("golden_message_oneof_implemented");
goldenMessage = readBytesFromResource("/google/protobuf/testdata/golden_message_oneof_implemented");
}
return goldenMessage;
}
@ -3846,12 +3869,13 @@ public final class TestUtil {
public static ByteString getGoldenPackedFieldsMessage() {
if (goldenPackedFieldsMessage == null) {
goldenPackedFieldsMessage =
readBytesFromFile("golden_packed_fields_message");
readBytesFromResource("/google/protobuf/testdata/golden_packed_fields_message");
}
return goldenPackedFieldsMessage;
}
private static ByteString goldenPackedFieldsMessage = null;
// BEGIN FULL-RUNTIME
/**
* Mock implementation of {@link GeneratedMessage.BuilderParent} for testing.
*
@ -3871,4 +3895,5 @@ public final class TestUtil {
return invalidations;
}
}
// END FULL-RUNTIME
}

View File

@ -382,17 +382,14 @@ public class TextFormatTest extends TestCase {
public void testMergeExtensions() throws Exception {
TestAllExtensions.Builder builder = TestAllExtensions.newBuilder();
TextFormat.merge(allExtensionsSetText,
TestUtil.getExtensionRegistry(),
builder);
TextFormat.merge(allExtensionsSetText, TestUtil.getFullExtensionRegistry(), builder);
TestUtil.assertAllExtensionsSet(builder.build());
}
public void testParseExtensions() throws Exception {
TestUtil.assertAllExtensionsSet(
TextFormat.parse(allExtensionsSetText,
TestUtil.getExtensionRegistry(),
TestAllExtensions.class));
TextFormat.parse(
allExtensionsSetText, TestUtil.getFullExtensionRegistry(), TestAllExtensions.class));
}
public void testMergeAndParseCompatibility() throws Exception {
@ -523,7 +520,7 @@ public class TextFormatTest extends TestCase {
// Test merge().
TestAllTypes.Builder builder = TestAllTypes.newBuilder();
try {
TextFormat.merge(text, TestUtil.getExtensionRegistry(), builder);
TextFormat.merge(text, TestUtil.getFullExtensionRegistry(), builder);
fail("Expected parse exception.");
} catch (TextFormat.ParseException e) {
assertEquals(error, e.getMessage());
@ -531,8 +528,7 @@ public class TextFormatTest extends TestCase {
// Test parse().
try {
TextFormat.parse(
text, TestUtil.getExtensionRegistry(), TestAllTypes.class);
TextFormat.parse(text, TestUtil.getFullExtensionRegistry(), TestAllTypes.class);
fail("Expected parse exception.");
} catch (TextFormat.ParseException e) {
assertEquals(error, e.getMessage());
@ -544,8 +540,7 @@ public class TextFormatTest extends TestCase {
String text) {
TestAllTypes.Builder builder = TestAllTypes.newBuilder();
try {
parserWithOverwriteForbidden.merge(
text, TestUtil.getExtensionRegistry(), builder);
parserWithOverwriteForbidden.merge(text, TestUtil.getFullExtensionRegistry(), builder);
fail("Expected parse exception.");
} catch (TextFormat.ParseException e) {
assertEquals(error, e.getMessage());
@ -555,8 +550,7 @@ public class TextFormatTest extends TestCase {
private TestAllTypes assertParseSuccessWithOverwriteForbidden(
String text) throws TextFormat.ParseException {
TestAllTypes.Builder builder = TestAllTypes.newBuilder();
parserWithOverwriteForbidden.merge(
text, TestUtil.getExtensionRegistry(), builder);
parserWithOverwriteForbidden.merge(text, TestUtil.getFullExtensionRegistry(), builder);
return builder.build();
}
@ -1118,8 +1112,7 @@ public class TextFormatTest extends TestCase {
String input = "foo_string: \"stringvalue\" foo_int: 123";
TestOneof2.Builder builder = TestOneof2.newBuilder();
try {
parserWithOverwriteForbidden.merge(
input, TestUtil.getExtensionRegistry(), builder);
parserWithOverwriteForbidden.merge(input, TestUtil.getFullExtensionRegistry(), builder);
fail("Expected parse exception.");
} catch (TextFormat.ParseException e) {
assertEquals("1:36: Field \"protobuf_unittest.TestOneof2.foo_int\""
@ -1131,7 +1124,7 @@ public class TextFormatTest extends TestCase {
public void testOneofOverwriteAllowed() throws Exception {
String input = "foo_string: \"stringvalue\" foo_int: 123";
TestOneof2.Builder builder = TestOneof2.newBuilder();
defaultParser.merge(input, TestUtil.getExtensionRegistry(), builder);
defaultParser.merge(input, TestUtil.getFullExtensionRegistry(), builder);
// Only the last value sticks.
TestOneof2 oneof = builder.build();
assertFalse(oneof.hasFooString());

View File

@ -132,7 +132,7 @@ public class WireFormatTest extends TestCase {
TestAllTypes message = TestUtil.getAllSet();
ByteString rawBytes = message.toByteString();
ExtensionRegistry registry = TestUtil.getExtensionRegistry();
ExtensionRegistryLite registry = TestUtil.getExtensionRegistry();
TestAllExtensions message2 =
TestAllExtensions.parseFrom(rawBytes, registry);
@ -145,7 +145,7 @@ public class WireFormatTest extends TestCase {
TestPackedExtensions message = TestUtil.getPackedExtensionsSet();
ByteString rawBytes = message.toByteString();
ExtensionRegistry registry = TestUtil.getExtensionRegistry();
ExtensionRegistryLite registry = TestUtil.getExtensionRegistry();
TestPackedExtensions message2 =
TestPackedExtensions.parseFrom(rawBytes, registry);

View File

@ -36,8 +36,6 @@ syntax = "proto2";
package protobuf_unittest;
option optimize_for = LITE_RUNTIME;
message LazyMessageLite {
optional int32 num = 1;
optional int32 num_with_default = 2 [default = 421];

View File

@ -34,7 +34,6 @@ syntax = "proto2";
package protobuf_unittest.lite_equals_and_hash;
option optimize_for = LITE_RUNTIME;
message TestOneofEquals {
oneof oneof_field {

View File

@ -30,10 +30,9 @@
syntax = "proto3";
package map_lite_test;
package map_test;
option optimize_for = LITE_RUNTIME;
option java_package = "map_lite_test";
option java_package = "map_test";
option java_outer_classname = "MapTestProto";
message TestMap {

View File

@ -38,7 +38,6 @@ syntax = "proto2";
package protobuf_unittest;
option optimize_for = LITE_RUNTIME;
import "com/google/protobuf/non_nested_extension_lite.proto";

View File

@ -36,7 +36,6 @@ syntax = "proto2";
package protobuf_unittest;
option optimize_for = LITE_RUNTIME;
message MessageLiteToBeExtended {
extensions 1 to max;

View File

@ -61,6 +61,9 @@ public final class Durations {
public static final Duration MAX_VALUE =
Duration.newBuilder().setSeconds(DURATION_SECONDS_MAX).setNanos(999999999).build();
/** A constant holding the duration of zero. */
public static final Duration ZERO = Duration.newBuilder().setSeconds(0L).setNanos(0).build();
private Durations() {}
private static final Comparator<Duration> COMPARATOR =

View File

@ -249,12 +249,9 @@ final class FieldMaskTree {
continue;
}
String childPath = path.isEmpty() ? entry.getKey() : path + "." + entry.getKey();
merge(
entry.getValue(),
childPath,
(Message) source.getField(field),
destination.getFieldBuilder(field),
options);
Message.Builder childBuilder = ((Message) destination.getField(field)).toBuilder();
merge(entry.getValue(), childPath, (Message) source.getField(field), childBuilder, options);
destination.setField(field, childBuilder.buildPartial());
continue;
}
if (field.isRepeated()) {
@ -275,7 +272,12 @@ final class FieldMaskTree {
}
} else {
if (source.hasField(field)) {
destination.getFieldBuilder(field).mergeFrom((Message) source.getField(field));
destination.setField(
field,
((Message) destination.getField(field))
.toBuilder()
.mergeFrom((Message) source.getField(field))
.build());
}
}
} else {

View File

@ -235,7 +235,7 @@ public class FieldMaskUtil {
/**
* Converts a FieldMask to its canonical form. In the canonical form of a
* FieldMask, all field paths are sorted alphabetically and redundant field
* paths are moved.
* paths are removed.
*/
public static FieldMask normalize(FieldMask mask) {
return new FieldMaskTree(mask).toFieldMask();

View File

@ -610,7 +610,7 @@ public class JsonFormat {
private final CharSequence blankOrNewLine;
private static class GsonHolder {
private static final Gson DEFAULT_GSON = new GsonBuilder().disableHtmlEscaping().create();
private static final Gson DEFAULT_GSON = new GsonBuilder().create();
}
PrinterImpl(

View File

@ -30,9 +30,14 @@
package com.google.protobuf.util;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.Message;
import com.google.protobuf.UninitializedMessageException;
import protobuf_unittest.UnittestProto.NestedTestAllTypes;
import protobuf_unittest.UnittestProto.TestAllTypes;
import protobuf_unittest.UnittestProto.TestAllTypes.NestedMessage;
import protobuf_unittest.UnittestProto.TestRequired;
import protobuf_unittest.UnittestProto.TestRequiredMessage;
import junit.framework.TestCase;
public class FieldMaskTreeTest extends TestCase {
@ -90,8 +95,68 @@ public class FieldMaskTreeTest extends TestCase {
tree.intersectFieldPath("bar", result);
assertEquals("bar.baz,bar.quz,foo", result.toString());
}
public void testMerge() throws Exception {
testMergeImpl(true);
testMergeImpl(false);
testMergeRequire(false);
testMergeRequire(true);
}
private void merge(
FieldMaskTree tree,
Message source,
Message.Builder builder,
FieldMaskUtil.MergeOptions options,
boolean useDynamicMessage)
throws Exception {
if (useDynamicMessage) {
Message.Builder newBuilder =
DynamicMessage.newBuilder(source.getDescriptorForType())
.mergeFrom(builder.buildPartial().toByteArray());
tree.merge(
DynamicMessage.newBuilder(source.getDescriptorForType())
.mergeFrom(source.toByteArray())
.build(),
newBuilder,
options);
builder.clear();
builder.mergeFrom(newBuilder.buildPartial());
} else {
tree.merge(source, builder, options);
}
}
private void testMergeRequire(boolean useDynamicMessage) throws Exception {
TestRequired value = TestRequired.newBuilder().setA(4321).setB(8765).setC(233333).build();
TestRequiredMessage source = TestRequiredMessage.newBuilder().setRequiredMessage(value).build();
FieldMaskUtil.MergeOptions options = new FieldMaskUtil.MergeOptions();
TestRequiredMessage.Builder builder = TestRequiredMessage.newBuilder();
merge(
new FieldMaskTree().addFieldPath("required_message.a"),
source,
builder,
options,
useDynamicMessage);
assertTrue(builder.hasRequiredMessage());
assertTrue(builder.getRequiredMessage().hasA());
assertFalse(builder.getRequiredMessage().hasB());
assertFalse(builder.getRequiredMessage().hasC());
merge(
new FieldMaskTree().addFieldPath("required_message.b").addFieldPath("required_message.c"),
source,
builder,
options,
useDynamicMessage);
try {
assertEquals(builder.build(), source);
} catch (UninitializedMessageException e) {
throw new AssertionError("required field isn't set", e);
}
}
private void testMergeImpl(boolean useDynamicMessage) throws Exception {
TestAllTypes value =
TestAllTypes.newBuilder()
.setOptionalInt32(1234)
@ -119,45 +184,51 @@ public class FieldMaskTreeTest extends TestCase {
// Test merging each individual field.
NestedTestAllTypes.Builder builder = NestedTestAllTypes.newBuilder();
new FieldMaskTree().addFieldPath("payload.optional_int32").merge(source, builder, options);
merge(new FieldMaskTree().addFieldPath("payload.optional_int32"),
source, builder, options, useDynamicMessage);
NestedTestAllTypes.Builder expected = NestedTestAllTypes.newBuilder();
expected.getPayloadBuilder().setOptionalInt32(1234);
assertEquals(expected.build(), builder.build());
builder = NestedTestAllTypes.newBuilder();
new FieldMaskTree()
.addFieldPath("payload.optional_nested_message")
.merge(source, builder, options);
merge(new FieldMaskTree().addFieldPath("payload.optional_nested_message"),
source, builder, options, useDynamicMessage);
expected = NestedTestAllTypes.newBuilder();
expected.getPayloadBuilder().setOptionalNestedMessage(NestedMessage.newBuilder().setBb(5678));
assertEquals(expected.build(), builder.build());
builder = NestedTestAllTypes.newBuilder();
new FieldMaskTree().addFieldPath("payload.repeated_int32").merge(source, builder, options);
merge(new FieldMaskTree().addFieldPath("payload.repeated_int32"),
source, builder, options, useDynamicMessage);
expected = NestedTestAllTypes.newBuilder();
expected.getPayloadBuilder().addRepeatedInt32(4321);
assertEquals(expected.build(), builder.build());
builder = NestedTestAllTypes.newBuilder();
new FieldMaskTree()
.addFieldPath("payload.repeated_nested_message")
.merge(source, builder, options);
merge(new FieldMaskTree().addFieldPath("payload.repeated_nested_message"),
source, builder, options, useDynamicMessage);
expected = NestedTestAllTypes.newBuilder();
expected.getPayloadBuilder().addRepeatedNestedMessage(NestedMessage.newBuilder().setBb(8765));
assertEquals(expected.build(), builder.build());
builder = NestedTestAllTypes.newBuilder();
new FieldMaskTree()
.addFieldPath("child.payload.optional_int32")
.merge(source, builder, options);
merge(
new FieldMaskTree().addFieldPath("child.payload.optional_int32"),
source,
builder,
options,
useDynamicMessage);
expected = NestedTestAllTypes.newBuilder();
expected.getChildBuilder().getPayloadBuilder().setOptionalInt32(1234);
assertEquals(expected.build(), builder.build());
builder = NestedTestAllTypes.newBuilder();
new FieldMaskTree()
.addFieldPath("child.payload.optional_nested_message")
.merge(source, builder, options);
merge(
new FieldMaskTree().addFieldPath("child.payload.optional_nested_message"),
source,
builder,
options,
useDynamicMessage);
expected = NestedTestAllTypes.newBuilder();
expected
.getChildBuilder()
@ -166,17 +237,15 @@ public class FieldMaskTreeTest extends TestCase {
assertEquals(expected.build(), builder.build());
builder = NestedTestAllTypes.newBuilder();
new FieldMaskTree()
.addFieldPath("child.payload.repeated_int32")
.merge(source, builder, options);
merge(new FieldMaskTree().addFieldPath("child.payload.repeated_int32"),
source, builder, options, useDynamicMessage);
expected = NestedTestAllTypes.newBuilder();
expected.getChildBuilder().getPayloadBuilder().addRepeatedInt32(4321);
assertEquals(expected.build(), builder.build());
builder = NestedTestAllTypes.newBuilder();
new FieldMaskTree()
.addFieldPath("child.payload.repeated_nested_message")
.merge(source, builder, options);
merge(new FieldMaskTree().addFieldPath("child.payload.repeated_nested_message"),
source, builder, options, useDynamicMessage);
expected = NestedTestAllTypes.newBuilder();
expected
.getChildBuilder()
@ -186,23 +255,23 @@ public class FieldMaskTreeTest extends TestCase {
// Test merging all fields.
builder = NestedTestAllTypes.newBuilder();
new FieldMaskTree()
.addFieldPath("child")
.addFieldPath("payload")
.merge(source, builder, options);
merge(new FieldMaskTree().addFieldPath("child").addFieldPath("payload"),
source, builder, options, useDynamicMessage);
assertEquals(source, builder.build());
// Test repeated options.
builder = NestedTestAllTypes.newBuilder();
builder.getPayloadBuilder().addRepeatedInt32(1000);
new FieldMaskTree().addFieldPath("payload.repeated_int32").merge(source, builder, options);
merge(new FieldMaskTree().addFieldPath("payload.repeated_int32"),
source, builder, options, useDynamicMessage);
// Default behavior is to append repeated fields.
assertEquals(2, builder.getPayload().getRepeatedInt32Count());
assertEquals(1000, builder.getPayload().getRepeatedInt32(0));
assertEquals(4321, builder.getPayload().getRepeatedInt32(1));
// Change to replace repeated fields.
options.setReplaceRepeatedFields(true);
new FieldMaskTree().addFieldPath("payload.repeated_int32").merge(source, builder, options);
merge(new FieldMaskTree().addFieldPath("payload.repeated_int32"),
source, builder, options, useDynamicMessage);
assertEquals(1, builder.getPayload().getRepeatedInt32Count());
assertEquals(4321, builder.getPayload().getRepeatedInt32(0));
@ -210,7 +279,8 @@ public class FieldMaskTreeTest extends TestCase {
builder = NestedTestAllTypes.newBuilder();
builder.getPayloadBuilder().setOptionalInt32(1000);
builder.getPayloadBuilder().setOptionalUint32(2000);
new FieldMaskTree().addFieldPath("payload").merge(source, builder, options);
merge(new FieldMaskTree().addFieldPath("payload"),
source, builder, options, useDynamicMessage);
// Default behavior is to merge message fields.
assertEquals(1234, builder.getPayload().getOptionalInt32());
assertEquals(2000, builder.getPayload().getOptionalUint32());
@ -218,14 +288,14 @@ public class FieldMaskTreeTest extends TestCase {
// Test merging unset message fields.
NestedTestAllTypes clearedSource = source.toBuilder().clearPayload().build();
builder = NestedTestAllTypes.newBuilder();
new FieldMaskTree().addFieldPath("payload").merge(clearedSource, builder, options);
merge(new FieldMaskTree().addFieldPath("payload"),
clearedSource, builder, options, useDynamicMessage);
assertEquals(false, builder.hasPayload());
// Skip a message field if they are unset in both source and target.
builder = NestedTestAllTypes.newBuilder();
new FieldMaskTree()
.addFieldPath("payload.optional_int32")
.merge(clearedSource, builder, options);
merge(new FieldMaskTree().addFieldPath("payload.optional_int32"),
clearedSource, builder, options, useDynamicMessage);
assertEquals(false, builder.hasPayload());
// Change to replace message fields.
@ -233,7 +303,8 @@ public class FieldMaskTreeTest extends TestCase {
builder = NestedTestAllTypes.newBuilder();
builder.getPayloadBuilder().setOptionalInt32(1000);
builder.getPayloadBuilder().setOptionalUint32(2000);
new FieldMaskTree().addFieldPath("payload").merge(source, builder, options);
merge(new FieldMaskTree().addFieldPath("payload"),
source, builder, options, useDynamicMessage);
assertEquals(1234, builder.getPayload().getOptionalInt32());
assertEquals(0, builder.getPayload().getOptionalUint32());
@ -241,7 +312,8 @@ public class FieldMaskTreeTest extends TestCase {
builder = NestedTestAllTypes.newBuilder();
builder.getPayloadBuilder().setOptionalInt32(1000);
builder.getPayloadBuilder().setOptionalUint32(2000);
new FieldMaskTree().addFieldPath("payload").merge(clearedSource, builder, options);
merge(new FieldMaskTree().addFieldPath("payload"),
clearedSource, builder, options, useDynamicMessage);
assertEquals(false, builder.hasPayload());
// Test merging unset primitive fields.
@ -249,18 +321,16 @@ public class FieldMaskTreeTest extends TestCase {
builder.getPayloadBuilder().clearOptionalInt32();
NestedTestAllTypes sourceWithPayloadInt32Unset = builder.build();
builder = source.toBuilder();
new FieldMaskTree()
.addFieldPath("payload.optional_int32")
.merge(sourceWithPayloadInt32Unset, builder, options);
merge(new FieldMaskTree().addFieldPath("payload.optional_int32"),
sourceWithPayloadInt32Unset, builder, options, useDynamicMessage);
assertEquals(true, builder.getPayload().hasOptionalInt32());
assertEquals(0, builder.getPayload().getOptionalInt32());
// Change to clear unset primitive fields.
options.setReplacePrimitiveFields(true);
builder = source.toBuilder();
new FieldMaskTree()
.addFieldPath("payload.optional_int32")
.merge(sourceWithPayloadInt32Unset, builder, options);
merge(new FieldMaskTree().addFieldPath("payload.optional_int32"),
sourceWithPayloadInt32Unset, builder, options, useDynamicMessage);
assertEquals(true, builder.hasPayload());
assertEquals(false, builder.getPayload().hasOptionalInt32());
}

View File

@ -70,7 +70,6 @@ import java.io.StringReader;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
@ -1188,10 +1187,14 @@ public class JsonFormatTest extends TestCase {
assertRoundTripEquals(message);
}
public void testDefaultGsonDoesNotHtmlEscape() throws Exception {
TestAllTypes message = TestAllTypes.newBuilder().setOptionalString("=").build();
assertEquals(
"{\n" + " \"optionalString\": \"=\"" + "\n}", JsonFormat.printer().print(message));
// Regression test for b/73832901. Make sure html tags are escaped.
public void testHtmlEscape() throws Exception {
TestAllTypes message = TestAllTypes.newBuilder().setOptionalString("</script>").build();
assertEquals("{\n \"optionalString\": \"\\u003c/script\\u003e\"\n}", toJsonString(message));
TestAllTypes.Builder builder = TestAllTypes.newBuilder();
JsonFormat.parser().merge(toJsonString(message), builder);
assertEquals(message.getOptionalString(), builder.getOptionalString());
}
public void testIncludingDefaultValueFields() throws Exception {

View File

@ -174,7 +174,7 @@ jspb.PrunerFunction;
/**
* A comparer function returns true if two protos are equal.
* @typedef {!function(?jspb.ConstBinaryMessage,
* @typedef {function(?jspb.ConstBinaryMessage,
* ?jspb.ConstBinaryMessage):boolean}
*/
jspb.ComparerFunction;

View File

@ -290,7 +290,9 @@ jspb.BinaryReader.prototype.nextField = function() {
nextWireType != jspb.BinaryConstants.WireType.DELIMITED &&
nextWireType != jspb.BinaryConstants.WireType.START_GROUP &&
nextWireType != jspb.BinaryConstants.WireType.END_GROUP) {
goog.asserts.fail('Invalid wire type');
goog.asserts.fail(
'Invalid wire type: %s (at position %s)', nextWireType,
this.fieldCursor_);
this.error_ = true;
return false;
}
@ -388,8 +390,7 @@ jspb.BinaryReader.prototype.skipFixed64Field = function() {
* Skips over the next group field in the binary stream.
*/
jspb.BinaryReader.prototype.skipGroup = function() {
// Keep a stack of start-group tags that must be matched by end-group tags.
var nestedGroups = [this.nextField_];
var previousField = this.nextField_;
do {
if (!this.nextField()) {
goog.asserts.fail('Unmatched start-group tag: stream EOF');
@ -397,19 +398,17 @@ jspb.BinaryReader.prototype.skipGroup = function() {
return;
}
if (this.nextWireType_ ==
jspb.BinaryConstants.WireType.START_GROUP) {
// Nested group start.
nestedGroups.push(this.nextField_);
} else if (this.nextWireType_ ==
jspb.BinaryConstants.WireType.END_GROUP) {
// Group end: check that it matches top-of-stack.
if (this.nextField_ != nestedGroups.pop()) {
if (this.nextField_ != previousField) {
goog.asserts.fail('Unmatched end-group tag');
this.error_ = true;
return;
}
return;
}
} while (nestedGroups.length > 0);
this.skipField();
} while (true);
};

View File

@ -679,9 +679,24 @@ describe('binaryReaderTest', function() {
writer.writeInt32(5, sentinel);
var dummyMessage = /** @type {!jspb.BinaryMessage} */({});
writer.writeGroup(5, dummyMessage, function() {
// Previously the skipGroup implementation was wrong, which only consume
// the decoder by nextField. This case is for making the previous
// implementation failed in skipGroup by an early end group tag.
// The reason is 44 = 5 * 8 + 4, this will be translated in to a field
// with number 5 and with type 4 (end group)
writer.writeInt64(44, 44);
// This will make previous implementation failed by invalid tag (7).
writer.writeInt64(42, 47);
writer.writeInt64(42, 42);
// This is for making the previous implementation failed by an invalid
// varint. The bytes have at least 9 consecutive minus byte, which will
// fail in this.nextField for previous implementation.
writer.writeBytes(43, [255, 255, 255, 255, 255, 255, 255, 255, 255, 255]);
writer.writeGroup(6, dummyMessage, function() {
writer.writeInt64(84, 42);
writer.writeInt64(84, 44);
writer.writeBytes(
43, [255, 255, 255, 255, 255, 255, 255, 255, 255, 255]);
});
});

View File

@ -971,8 +971,9 @@ jspb.utils.byteSourceToUint8Array = function(data) {
return /** @type {!Uint8Array} */(new Uint8Array(data));
}
if (data.constructor === Buffer) {
return /** @type {!Uint8Array} */(new Uint8Array(data));
if (typeof Buffer != 'undefined' && data.constructor === Buffer) {
return /** @type {!Uint8Array} */ (
new Uint8Array(/** @type {?} */ (data)));
}
if (data.constructor === Array) {

View File

@ -136,7 +136,7 @@ jspb.Map.prototype.toArray = function() {
*
* @param {boolean=} includeInstance Whether to include the JSPB instance for
* transitional soy proto support: http://goto/soy-param-migration
* @param {!function((boolean|undefined),V):!Object=} valueToObject
* @param {function((boolean|undefined),V):!Object=} valueToObject
* The static toObject() method, if V is a message type.
* @return {!Array<!Array<!Object>>}
*/
@ -165,9 +165,9 @@ jspb.Map.prototype.toObject = function(includeInstance, valueToObject) {
*
* @template K, V
* @param {!Array<!Array<!Object>>} entries
* @param {!function(new:V,?=)} valueCtor
* @param {function(new:V,?=)} valueCtor
* The constructor for type V.
* @param {!function(!Object):V} valueFromObject
* @param {function(!Object):V} valueFromObject
* The fromObject function for type V.
* @return {!jspb.Map<K, V>}
*/
@ -410,9 +410,9 @@ jspb.Map.prototype.has = function(key) {
* number.
* @param {number} fieldNumber
* @param {!jspb.BinaryWriter} writer
* @param {!function(this:jspb.BinaryWriter,number,K)} keyWriterFn
* @param {function(this:jspb.BinaryWriter,number,K)} keyWriterFn
* The method on BinaryWriter that writes type K to the stream.
* @param {!function(this:jspb.BinaryWriter,number,V,?=)|
* @param {function(this:jspb.BinaryWriter,number,V,?=)|
* function(this:jspb.BinaryWriter,number,V,?)} valueWriterFn
* The method on BinaryWriter that writes type V to the stream. May be
* writeMessage, in which case the second callback arg form is used.
@ -448,10 +448,10 @@ jspb.Map.prototype.serializeBinary = function(
* @template K, V
* @param {!jspb.Map} map
* @param {!jspb.BinaryReader} reader
* @param {!function(this:jspb.BinaryReader):K} keyReaderFn
* @param {function(this:jspb.BinaryReader):K} keyReaderFn
* The method on BinaryReader that reads type K from the stream.
*
* @param {!function(this:jspb.BinaryReader):V|
* @param {function(this:jspb.BinaryReader):V|
* function(this:jspb.BinaryReader,V,
* function(V,!jspb.BinaryReader))} valueReaderFn
* The method on BinaryReader that reads type V from the stream. May be

View File

@ -439,9 +439,19 @@ jspb.Message.isArray_ = function(o) {
* @private
*/
jspb.Message.initPivotAndExtensionObject_ = function(msg, suggestedPivot) {
if (msg.array.length) {
var foundIndex = msg.array.length - 1;
var obj = msg.array[foundIndex];
// There are 3 variants that need to be dealt with which are the
// combination of whether there exists an extension object (EO) and
// whether there is a suggested pivot (SP).
//
// EO, ? : pivot is the index of the EO
// no-EO, no-SP: pivot is MAX_INT
// no-EO, SP : pivot is the max(lastindex + 1, SP)
var msgLength = msg.array.length;
var lastIndex = -1;
if (msgLength) {
lastIndex = msgLength - 1;
var obj = msg.array[lastIndex];
// Normal fields are never objects, so we can be sure that if we find an
// object here, then it's the extension object. However, we must ensure that
// the object is not an array, since arrays are valid field values.
@ -449,14 +459,17 @@ jspb.Message.initPivotAndExtensionObject_ = function(msg, suggestedPivot) {
// in Safari on iOS 8. See the description of CL/86511464 for details.
if (obj && typeof obj == 'object' && !jspb.Message.isArray_(obj) &&
!(jspb.Message.SUPPORTS_UINT8ARRAY_ && obj instanceof Uint8Array)) {
msg.pivot_ = jspb.Message.getFieldNumber_(msg, foundIndex);
msg.pivot_ = jspb.Message.getFieldNumber_(msg, lastIndex);
msg.extensionObject_ = obj;
return;
}
}
if (suggestedPivot > -1) {
msg.pivot_ = suggestedPivot;
// If a extension object is not present, set the pivot value as being
// after the last value in the array to avoid overwriting values, etc.
msg.pivot_ = Math.max(
suggestedPivot, jspb.Message.getFieldNumber_(msg, lastIndex + 1));
// Avoid changing the shape of the proto with an empty extension object by
// deferring the materialization of the extension object until the first
// time a field set into it (may be due to getting a repeated proto field
@ -922,17 +935,6 @@ jspb.Message.setProto3IntField = function(msg, fieldNumber, value) {
};
/**
* Sets the value of a non-extension integer, handled as string, field of a proto3
* @param {!jspb.Message} msg A jspb proto.
* @param {number} fieldNumber The field number.
* @param {number} value New value
* @protected
*/
jspb.Message.setProto3StringIntField = function(msg, fieldNumber, value) {
jspb.Message.setFieldIgnoringDefault_(msg, fieldNumber, value, '0');
};
/**
* Sets the value of a non-extension floating point field of a proto3
* @param {!jspb.Message} msg A jspb proto.
@ -993,12 +995,22 @@ jspb.Message.setProto3EnumField = function(msg, fieldNumber, value) {
};
/**
* Sets the value of a non-extension int field of a proto3 that has jstype set
* to String.
* @param {!jspb.Message} msg A jspb proto.
* @param {number} fieldNumber The field number.
* @param {string} value New value
* @protected
*/
jspb.Message.setProto3StringIntField = function(msg, fieldNumber, value) {
jspb.Message.setFieldIgnoringDefault_(msg, fieldNumber, value, "0");
};
/**
* Sets the value of a non-extension primitive field, with proto3 (non-nullable
* primitives) semantics of ignoring values that are equal to the type's
* default.
* @template T
* @param {!jspb.Message} msg A jspb proto.
* @param {number} fieldNumber The field number.
* @param {!Uint8Array|string|number|boolean|undefined} value New value
@ -1007,7 +1019,7 @@ jspb.Message.setProto3EnumField = function(msg, fieldNumber, value) {
*/
jspb.Message.setFieldIgnoringDefault_ = function(
msg, fieldNumber, value, defaultValue) {
if (value != defaultValue) {
if (value !== defaultValue) {
jspb.Message.setField(msg, fieldNumber, value);
} else {
msg.array[jspb.Message.getIndex_(msg, fieldNumber)] = null;
@ -1127,7 +1139,7 @@ jspb.Message.getWrapperField = function(msg, ctor, fieldNumber, opt_required) {
* @param {!jspb.Message} msg A jspb proto.
* @param {function(new:jspb.Message, Array)} ctor Constructor for the field.
* @param {number} fieldNumber The field number.
* @return {Array<!jspb.Message>} The repeated field as an array of protos.
* @return {!Array<!jspb.Message>} The repeated field as an array of protos.
* @protected
*/
jspb.Message.getRepeatedWrapperField = function(msg, ctor, fieldNumber) {

View File

@ -73,6 +73,7 @@ goog.require('proto.jspb.test.Simple1');
goog.require('proto.jspb.test.Simple2');
goog.require('proto.jspb.test.SpecialCases');
goog.require('proto.jspb.test.TestClone');
goog.require('proto.jspb.test.TestCloneExtension');
goog.require('proto.jspb.test.TestEndsWithBytes');
goog.require('proto.jspb.test.TestGroup');
goog.require('proto.jspb.test.TestGroup1');

View File

@ -165,6 +165,13 @@ message TestClone {
extensions 10 to max;
}
message TestCloneExtension {
extend TestClone {
optional TestCloneExtension low_ext = 11;
}
optional int32 f = 1;
}
message CloneExtension {
extend TestClone {
optional CloneExtension ext_field = 100;

View File

@ -410,7 +410,8 @@ class TextFormatTest(unittest.TestCase):
text = 'optional_nested_enum: BARR'
self.assertRaisesWithMessage(
text_format.ParseError,
('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
('1:23 : \'optional_nested_enum: BARR\': '
'Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
'has no value named BARR.'),
text_format.Merge, text, message)
@ -418,7 +419,8 @@ class TextFormatTest(unittest.TestCase):
text = 'optional_nested_enum: 100'
self.assertRaisesWithMessage(
text_format.ParseError,
('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
('1:23 : \'optional_nested_enum: 100\': '
'Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
'has no value with number 100.'),
text_format.Merge, text, message)
@ -427,7 +429,8 @@ class TextFormatTest(unittest.TestCase):
text = 'optional_int32: bork'
self.assertRaisesWithMessage(
text_format.ParseError,
('1:17 : Couldn\'t parse integer: bork'),
('1:17 : \'optional_int32: bork\': '
'Couldn\'t parse integer: bork'),
text_format.Merge, text, message)
def testMergeStringFieldUnescape(self):

View File

@ -76,6 +76,9 @@ class DescriptorDatabase(object):
self._AddSymbol(name, file_desc_proto)
for enum in file_desc_proto.enum_type:
self._AddSymbol(('.'.join((package, enum.name))), file_desc_proto)
for enum_value in enum.value:
self._file_desc_protos_by_symbol[
'.'.join((package, enum_value.name))] = file_desc_proto
for extension in file_desc_proto.extension:
self._AddSymbol(('.'.join((package, extension.name))), file_desc_proto)
for service in file_desc_proto.service:
@ -133,6 +136,14 @@ class DescriptorDatabase(object):
top_level, _, _ = symbol.rpartition('.')
return self._file_desc_protos_by_symbol[top_level]
def FindFileContainingExtension(self, extendee_name, extension_number):
# TODO(jieluo): implement this API.
return None
def FindAllExtensionNumbers(self, extendee_name):
# TODO(jieluo): implement this API.
return []
def _AddSymbol(self, name, file_desc_proto):
if name in self._file_desc_protos_by_symbol:
warn_msg = ('Conflict register for file "' + file_desc_proto.name +

View File

@ -131,33 +131,46 @@ class DescriptorPool(object):
# TODO(jieluo): Remove _file_desc_by_toplevel_extension after
# maybe year 2020 for compatibility issue (with 3.4.1 only).
self._file_desc_by_toplevel_extension = {}
self._top_enum_values = {}
# We store extensions in two two-level mappings: The first key is the
# descriptor of the message being extended, the second key is the extension
# full name or its tag number.
self._extensions_by_name = collections.defaultdict(dict)
self._extensions_by_number = collections.defaultdict(dict)
def _CheckConflictRegister(self, desc):
def _CheckConflictRegister(self, desc, desc_name, file_name):
"""Check if the descriptor name conflicts with another of the same name.
Args:
desc: Descriptor of a message, enum, service or extension.
desc: Descriptor of a message, enum, service, extension or enum value.
desc_name: the full name of desc.
file_name: The file name of descriptor.
"""
desc_name = desc.full_name
for register, descriptor_type in [
(self._descriptors, descriptor.Descriptor),
(self._enum_descriptors, descriptor.EnumDescriptor),
(self._service_descriptors, descriptor.ServiceDescriptor),
(self._toplevel_extensions, descriptor.FieldDescriptor)]:
(self._toplevel_extensions, descriptor.FieldDescriptor),
(self._top_enum_values, descriptor.EnumValueDescriptor)]:
if desc_name in register:
file_name = register[desc_name].file.name
old_desc = register[desc_name]
if isinstance(old_desc, descriptor.EnumValueDescriptor):
old_file = old_desc.type.file.name
else:
old_file = old_desc.file.name
if not isinstance(desc, descriptor_type) or (
file_name != desc.file.name):
warn_msg = ('Conflict register for file "' + desc.file.name +
old_file != file_name):
warn_msg = ('Conflict register for file "' + file_name +
'": ' + desc_name +
' is already defined in file "' +
file_name + '"')
old_file + '"')
if isinstance(desc, descriptor.EnumValueDescriptor):
warn_msg += ('\nNote: enum values appear as '
'siblings of the enum type instead of '
'children of it.')
warnings.warn(warn_msg, RuntimeWarning)
return
def Add(self, file_desc_proto):
@ -196,7 +209,7 @@ class DescriptorPool(object):
if not isinstance(desc, descriptor.Descriptor):
raise TypeError('Expected instance of descriptor.Descriptor.')
self._CheckConflictRegister(desc)
self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
self._descriptors[desc.full_name] = desc
self._AddFileDescriptor(desc.file)
@ -213,8 +226,26 @@ class DescriptorPool(object):
if not isinstance(enum_desc, descriptor.EnumDescriptor):
raise TypeError('Expected instance of descriptor.EnumDescriptor.')
self._CheckConflictRegister(enum_desc)
file_name = enum_desc.file.name
self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
self._enum_descriptors[enum_desc.full_name] = enum_desc
# Top enum values need to be indexed.
# Count the number of dots to see whether the enum is toplevel or nested
# in a message. We cannot use enum_desc.containing_type at this stage.
if enum_desc.file.package:
top_level = (enum_desc.full_name.count('.')
- enum_desc.file.package.count('.') == 1)
else:
top_level = enum_desc.full_name.count('.') == 0
if top_level:
file_name = enum_desc.file.name
package = enum_desc.file.package
for enum_value in enum_desc.values:
full_name = _NormalizeFullyQualifiedName(
'.'.join((package, enum_value.name)))
self._CheckConflictRegister(enum_value, full_name, file_name)
self._top_enum_values[full_name] = enum_value
self._AddFileDescriptor(enum_desc.file)
def AddServiceDescriptor(self, service_desc):
@ -227,7 +258,8 @@ class DescriptorPool(object):
if not isinstance(service_desc, descriptor.ServiceDescriptor):
raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
self._CheckConflictRegister(service_desc)
self._CheckConflictRegister(service_desc, service_desc.full_name,
service_desc.file.name)
self._service_descriptors[service_desc.full_name] = service_desc
def AddExtensionDescriptor(self, extension):
@ -247,7 +279,6 @@ class DescriptorPool(object):
raise TypeError('Expected an extension descriptor.')
if extension.extension_scope is None:
self._CheckConflictRegister(extension)
self._toplevel_extensions[extension.full_name] = extension
try:
@ -348,6 +379,30 @@ class DescriptorPool(object):
"""
symbol = _NormalizeFullyQualifiedName(symbol)
try:
return self._InternalFindFileContainingSymbol(symbol)
except KeyError:
pass
try:
# Try fallback database. Build and find again if possible.
self._FindFileContainingSymbolInDb(symbol)
return self._InternalFindFileContainingSymbol(symbol)
except KeyError:
raise KeyError('Cannot find a file containing %s' % symbol)
def _InternalFindFileContainingSymbol(self, symbol):
"""Gets the already built FileDescriptor containing the specified symbol.
Args:
symbol: The name of the symbol to search for.
Returns:
A FileDescriptor that contains the specified symbol.
Raises:
KeyError: if the file cannot be found in the pool.
"""
try:
return self._descriptors[symbol].file
except KeyError:
@ -364,7 +419,7 @@ class DescriptorPool(object):
pass
try:
return self._FindFileContainingSymbolInDb(symbol)
return self._top_enum_values[symbol].type.file
except KeyError:
pass
@ -373,13 +428,15 @@ class DescriptorPool(object):
except KeyError:
pass
# Try nested extensions inside a message.
message_name, _, extension_name = symbol.rpartition('.')
# Try fields, enum values and nested extensions inside a message.
top_name, _, sub_name = symbol.rpartition('.')
try:
message = self.FindMessageTypeByName(message_name)
assert message.extensions_by_name[extension_name]
message = self.FindMessageTypeByName(top_name)
assert (sub_name in message.extensions_by_name or
sub_name in message.fields_by_name or
sub_name in message.enum_values_by_name)
return message.file
except KeyError:
except (KeyError, AssertionError):
raise KeyError('Cannot find a file containing %s' % symbol)
def FindMessageTypeByName(self, full_name):
@ -499,7 +556,11 @@ class DescriptorPool(object):
KeyError: when no extension with the given number is known for the
specified message.
"""
return self._extensions_by_number[message_descriptor][number]
try:
return self._extensions_by_number[message_descriptor][number]
except KeyError:
self._TryLoadExtensionFromDB(message_descriptor, number)
return self._extensions_by_number[message_descriptor][number]
def FindAllExtensions(self, message_descriptor):
"""Gets all the known extension of a given message.
@ -513,8 +574,57 @@ class DescriptorPool(object):
Returns:
A list of FieldDescriptor describing the extensions.
"""
# Fallback to descriptor db if FindAllExtensionNumbers is provided.
if self._descriptor_db and hasattr(
self._descriptor_db, 'FindAllExtensionNumbers'):
full_name = message_descriptor.full_name
all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
for number in all_numbers:
if number in self._extensions_by_number[message_descriptor]:
continue
self._TryLoadExtensionFromDB(message_descriptor, number)
return list(self._extensions_by_number[message_descriptor].values())
def _TryLoadExtensionFromDB(self, message_descriptor, number):
"""Try to Load extensions from decriptor db.
Args:
message_descriptor: descriptor of the extended message.
number: the extension number that needs to be loaded.
"""
if not self._descriptor_db:
return
# Only supported when FindFileContainingExtension is provided.
if not hasattr(
self._descriptor_db, 'FindFileContainingExtension'):
return
full_name = message_descriptor.full_name
file_proto = self._descriptor_db.FindFileContainingExtension(
full_name, number)
if file_proto is None:
return
try:
file_desc = self._ConvertFileProtoToFileDescriptor(file_proto)
for extension in file_desc.extensions_by_name.values():
self._extensions_by_number[extension.containing_type][
extension.number] = extension
self._extensions_by_name[extension.containing_type][
extension.full_name] = extension
for message_type in file_desc.message_types_by_name.values():
for extension in message_type.extensions:
self._extensions_by_number[extension.containing_type][
extension.number] = extension
self._extensions_by_name[extension.containing_type][
extension.full_name] = extension
except:
warn_msg = ('Unable to load proto file %s for extension number %d.' %
(file_proto.name, number))
warnings.warn(warn_msg, RuntimeWarning)
def FindServiceByName(self, full_name):
"""Loads the named service descriptor from the pool.
@ -532,6 +642,23 @@ class DescriptorPool(object):
self._FindFileContainingSymbolInDb(full_name)
return self._service_descriptors[full_name]
def FindMethodByName(self, full_name):
"""Loads the named service method descriptor from the pool.
Args:
full_name: The full name of the method descriptor to load.
Returns:
The method descriptor for the service method.
Raises:
KeyError: if the method cannot be found in the pool.
"""
full_name = _NormalizeFullyQualifiedName(full_name)
service_name, _, method_name = full_name.rpartition('.')
service_descriptor = self.FindServiceByName(service_name)
return service_descriptor.methods_by_name[method_name]
def _FindFileContainingSymbolInDb(self, symbol):
"""Finds the file in descriptor DB containing the specified symbol.
@ -567,7 +694,6 @@ class DescriptorPool(object):
Returns:
A FileDescriptor matching the passed in proto.
"""
if file_proto.name not in self._file_descriptors:
built_deps = list(self._GetDeps(file_proto.dependency))
direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
@ -604,7 +730,7 @@ class DescriptorPool(object):
for enum_type in file_proto.enum_type:
file_descriptor.enum_types_by_name[enum_type.name] = (
self._ConvertEnumDescriptor(enum_type, file_proto.package,
file_descriptor, None, scope))
file_descriptor, None, scope, True))
for index, extension_proto in enumerate(file_proto.extension):
extension_desc = self._MakeFieldDescriptor(
@ -616,6 +742,8 @@ class DescriptorPool(object):
file_descriptor.package, scope)
file_descriptor.extensions_by_name[extension_desc.name] = (
extension_desc)
self._file_desc_by_toplevel_extension[extension_desc.full_name] = (
file_descriptor)
for desc_proto in file_proto.message_type:
self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
@ -673,7 +801,8 @@ class DescriptorPool(object):
nested, desc_name, file_desc, scope, syntax)
for nested in desc_proto.nested_type]
enums = [
self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope)
self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
scope, False)
for enum in desc_proto.enum_type]
fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
for index, field in enumerate(desc_proto.field)]
@ -718,12 +847,12 @@ class DescriptorPool(object):
fields[field_index].containing_oneof = oneofs[oneof_index]
scope[_PrefixWithDot(desc_name)] = desc
self._CheckConflictRegister(desc)
self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
self._descriptors[desc_name] = desc
return desc
def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
containing_type=None, scope=None):
containing_type=None, scope=None, top_level=False):
"""Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
Args:
@ -732,6 +861,8 @@ class DescriptorPool(object):
file_desc: The file containing the enum descriptor.
containing_type: The type containing this enum.
scope: Scope containing available types.
top_level: If True, the enum is a top level symbol. If False, the enum
is defined inside a message.
Returns:
The added descriptor
@ -757,8 +888,17 @@ class DescriptorPool(object):
containing_type=containing_type,
options=_OptionsOrNone(enum_proto))
scope['.%s' % enum_name] = desc
self._CheckConflictRegister(desc)
self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
self._enum_descriptors[enum_name] = desc
# Add top level enum values.
if top_level:
for value in values:
full_name = _NormalizeFullyQualifiedName(
'.'.join((package, value.name)))
self._CheckConflictRegister(value, full_name, file_name)
self._top_enum_values[full_name] = value
return desc
def _MakeFieldDescriptor(self, field_proto, message_name, index,
@ -885,6 +1025,8 @@ class DescriptorPool(object):
elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
field_desc.default_value = text_encoding.CUnescape(
field_proto.default_value)
elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
field_desc.default_value = None
else:
# All other types are of the "int" type.
field_desc.default_value = int(field_proto.default_value)
@ -901,6 +1043,8 @@ class DescriptorPool(object):
field_desc.default_value = field_desc.enum_type.values[0].number
elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
field_desc.default_value = b''
elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
field_desc.default_value = None
else:
# All other types are of the "int" type.
field_desc.default_value = 0
@ -954,7 +1098,7 @@ class DescriptorPool(object):
methods=methods,
options=_OptionsOrNone(service_proto),
file=file_desc)
self._CheckConflictRegister(desc)
self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
self._service_descriptors[service_name] = desc
return desc

View File

@ -0,0 +1,30 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -145,29 +145,3 @@ def Version():
# For internal use only
def IsPythonDefaultSerializationDeterministic():
return _python_deterministic_proto_serialization
# DO NOT USE: For migration and testing only. Will be removed when Proto3
# defaults to preserve unknowns.
if _implementation_type == 'cpp':
try:
# pylint: disable=g-import-not-at-top
from google.protobuf.pyext import _message
def GetPythonProto3PreserveUnknownsDefault():
return _message.GetPythonProto3PreserveUnknownsDefault()
def SetPythonProto3PreserveUnknownsDefault(preserve):
_message.SetPythonProto3PreserveUnknownsDefault(preserve)
except ImportError:
# Unrecognized cpp implementation. Skipping the unknown fields APIs.
pass
else:
_python_proto3_preserve_unknowns_default = True
def GetPythonProto3PreserveUnknownsDefault():
return _python_proto3_preserve_unknowns_default
def SetPythonProto3PreserveUnknownsDefault(preserve):
global _python_proto3_preserve_unknowns_default
_python_proto3_preserve_unknowns_default = preserve

View File

@ -628,3 +628,130 @@ class MessageMap(MutableMapping):
def GetEntryClass(self):
return self._entry_descriptor._concrete_class
class _UnknownField(object):
"""A parsed unknown field."""
# Disallows assignment to other attributes.
__slots__ = ['_field_number', '_wire_type', '_data']
def __init__(self, field_number, wire_type, data):
self._field_number = field_number
self._wire_type = wire_type
self._data = data
return
def __lt__(self, other):
# pylint: disable=protected-access
return self._field_number < other._field_number
def __eq__(self, other):
if self is other:
return True
# pylint: disable=protected-access
return (self._field_number == other._field_number and
self._wire_type == other._wire_type and
self._data == other._data)
class UnknownFieldRef(object):
def __init__(self, parent, index):
self._parent = parent
self._index = index
return
def _check_valid(self):
if not self._parent:
raise ValueError('UnknownField does not exist. '
'The parent message might be cleared.')
if self._index >= len(self._parent):
raise ValueError('UnknownField does not exist. '
'The parent message might be cleared.')
@property
def field_number(self):
self._check_valid()
# pylint: disable=protected-access
return self._parent._internal_get(self._index)._field_number
@property
def wire_type(self):
self._check_valid()
# pylint: disable=protected-access
return self._parent._internal_get(self._index)._wire_type
@property
def data(self):
self._check_valid()
# pylint: disable=protected-access
return self._parent._internal_get(self._index)._data
class UnknownFieldSet(object):
"""UnknownField container"""
# Disallows assignment to other attributes.
__slots__ = ['_values']
def __init__(self):
self._values = []
def __getitem__(self, index):
if self._values is None:
raise ValueError('UnknownFields does not exist. '
'The parent message might be cleared.')
size = len(self._values)
if index < 0:
index += size
if index < 0 or index >= size:
raise IndexError('index %d out of range'.index)
return UnknownFieldRef(self, index)
def _internal_get(self, index):
return self._values[index]
def __len__(self):
if self._values is None:
raise ValueError('UnknownFields does not exist. '
'The parent message might be cleared.')
return len(self._values)
def _add(self, field_number, wire_type, data):
unknown_field = _UnknownField(field_number, wire_type, data)
self._values.append(unknown_field)
return unknown_field
def __iter__(self):
for i in range(len(self)):
yield UnknownFieldRef(self, i)
def _extend(self, other):
if other is None:
return
# pylint: disable=protected-access
self._values.extend(other._values)
def __eq__(self, other):
if self is other:
return True
# Sort unknown fields because their order shouldn't
# affect equality test.
values = list(self._values)
if other is None:
return not values
values.sort()
# pylint: disable=protected-access
other_values = sorted(other._values)
return values == other_values
def _clear(self):
for value in self._values:
# pylint: disable=protected-access
if isinstance(value._data, UnknownFieldSet):
value._data._clear() # pylint: disable=protected-access
self._values = None

View File

@ -86,7 +86,11 @@ import six
if six.PY3:
long = int
else:
import re # pylint: disable=g-import-not-at-top
_SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]'))
from google.protobuf.internal import containers
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message
@ -167,7 +171,7 @@ _DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
def ReadTag(buffer, pos):
"""Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.
"""Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.
We return the raw bytes of the tag rather than decoding them. The raw
bytes can then be used to look up the proper decoder. This effectively allows
@ -175,13 +179,21 @@ def ReadTag(buffer, pos):
for work that is done in C (searching for a byte string in a hash table).
In a low-level language it would be much cheaper to decode the varint and
use that, but not in Python.
"""
Args:
buffer: memoryview object of the encoded bytes
pos: int of the current position to start from
Returns:
Tuple[bytes, int] of the tag data and new position.
"""
start = pos
while six.indexbytes(buffer, pos) & 0x80:
pos += 1
pos += 1
return (six.binary_type(buffer[start:pos]), pos)
tag_bytes = buffer[start:pos].tobytes()
return tag_bytes, pos
# --------------------------------------------------------------------
@ -295,10 +307,20 @@ def _FloatDecoder():
local_unpack = struct.unpack
def InnerDecode(buffer, pos):
"""Decode serialized float to a float and new position.
Args:
buffer: memoryview of the serialized bytes
pos: int, position in the memory view to start at.
Returns:
Tuple[float, int] of the deserialized float value and new position
in the serialized data.
"""
# We expect a 32-bit value in little-endian byte order. Bit 1 is the sign
# bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
new_pos = pos + 4
float_bytes = buffer[pos:new_pos]
float_bytes = buffer[pos:new_pos].tobytes()
# If this value has all its exponent bits set, then it's non-finite.
# In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
@ -329,10 +351,20 @@ def _DoubleDecoder():
local_unpack = struct.unpack
def InnerDecode(buffer, pos):
"""Decode serialized double to a double and new position.
Args:
buffer: memoryview of the serialized bytes.
pos: int, position in the memory view to start at.
Returns:
Tuple[float, int] of the decoded double value and new position
in the serialized data.
"""
# We expect a 64-bit value in little-endian byte order. Bit 1 is the sign
# bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
new_pos = pos + 8
double_bytes = buffer[pos:new_pos]
double_bytes = buffer[pos:new_pos].tobytes()
# If this value has all its exponent bits set and at least one significand
# bit set, it's not a number. In Python 2.4, struct.unpack will treat it
@ -355,6 +387,18 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
if is_packed:
local_DecodeVarint = _DecodeVarint
def DecodePackedField(buffer, pos, end, message, field_dict):
"""Decode serialized packed enum to its value and a new position.
Args:
buffer: memoryview of the serialized bytes.
pos: int, position in the memory view to start at.
end: int, end position of serialized data
message: Message object to store unknown fields in
field_dict: Map[Descriptor, Any] to store decoded values in.
Returns:
int, new position in serialized data.
"""
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@ -365,6 +409,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
while pos < endpoint:
value_start_pos = pos
(element, pos) = _DecodeSignedVarint32(buffer, pos)
# pylint: disable=protected-access
if element in enum_type.values_by_number:
value.append(element)
else:
@ -372,8 +417,10 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
message._unknown_fields = []
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_VARINT)
message._unknown_fields.append(
(tag_bytes, buffer[value_start_pos:pos]))
(tag_bytes, buffer[value_start_pos:pos].tobytes()))
# pylint: enable=protected-access
if pos > endpoint:
if element in enum_type.values_by_number:
del value[-1] # Discard corrupt value.
@ -386,18 +433,32 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
tag_len = len(tag_bytes)
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
"""Decode serialized repeated enum to its value and a new position.
Args:
buffer: memoryview of the serialized bytes.
pos: int, position in the memory view to start at.
end: int, end position of serialized data
message: Message object to store unknown fields in
field_dict: Map[Descriptor, Any] to store decoded values in.
Returns:
int, new position in serialized data.
"""
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
(element, new_pos) = _DecodeSignedVarint32(buffer, pos)
# pylint: disable=protected-access
if element in enum_type.values_by_number:
value.append(element)
else:
if not message._unknown_fields:
message._unknown_fields = []
message._unknown_fields.append(
(tag_bytes, buffer[pos:new_pos]))
(tag_bytes, buffer[pos:new_pos].tobytes()))
# pylint: enable=protected-access
# Predict that the next tag is another copy of the same repeated
# field.
pos = new_pos + tag_len
@ -409,10 +470,23 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
return DecodeRepeatedField
else:
def DecodeField(buffer, pos, end, message, field_dict):
"""Decode serialized repeated enum to its value and a new position.
Args:
buffer: memoryview of the serialized bytes.
pos: int, position in the memory view to start at.
end: int, end position of serialized data
message: Message object to store unknown fields in
field_dict: Map[Descriptor, Any] to store decoded values in.
Returns:
int, new position in serialized data.
"""
value_start_pos = pos
(enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
if pos > end:
raise _DecodeError('Truncated message.')
# pylint: disable=protected-access
if enum_value in enum_type.values_by_number:
field_dict[key] = enum_value
else:
@ -421,7 +495,8 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_VARINT)
message._unknown_fields.append(
(tag_bytes, buffer[value_start_pos:pos]))
(tag_bytes, buffer[value_start_pos:pos].tobytes()))
# pylint: enable=protected-access
return pos
return DecodeField
@ -458,20 +533,33 @@ BoolDecoder = _ModifiedDecoder(
wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
is_strict_utf8=False):
"""Returns a decoder for a string field."""
local_DecodeVarint = _DecodeVarint
local_unicode = six.text_type
def _ConvertToUnicode(byte_str):
def _ConvertToUnicode(memview):
"""Convert byte to unicode."""
byte_str = memview.tobytes()
try:
return local_unicode(byte_str, 'utf-8')
value = local_unicode(byte_str, 'utf-8')
except UnicodeDecodeError as e:
# add more information to the error message and re-raise it.
e.reason = '%s in field: %s' % (e, key.full_name)
raise
if is_strict_utf8 and six.PY2:
if _SURROGATE_PATTERN.search(value):
reason = ('String field %s contains invalid UTF-8 data when parsing'
'a protocol buffer: surrogates not allowed. Use'
'the bytes type if you intend to send raw bytes.') % (
key.full_name)
raise message.DecodeError(reason)
return value
assert not is_packed
if is_repeated:
tag_bytes = encoder.TagBytes(field_number,
@ -523,7 +611,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
value.append(buffer[pos:new_pos])
value.append(buffer[pos:new_pos].tobytes())
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@ -536,7 +624,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
field_dict[key] = buffer[pos:new_pos]
field_dict[key] = buffer[pos:new_pos].tobytes()
return new_pos
return DecodeField
@ -665,6 +753,18 @@ def MessageSetItemDecoder(descriptor):
local_SkipField = SkipField
def DecodeItem(buffer, pos, end, message, field_dict):
"""Decode serialized message set to its value and new position.
Args:
buffer: memoryview of the serialized bytes.
pos: int, position in the memory view to start at.
end: int, end position of serialized data
message: Message object to store unknown fields in
field_dict: Map[Descriptor, Any] to store decoded values in.
Returns:
int, new position in serialized data.
"""
message_set_item_start = pos
type_id = -1
message_start = -1
@ -695,6 +795,7 @@ def MessageSetItemDecoder(descriptor):
raise _DecodeError('MessageSet item missing message.')
extension = message.Extensions._FindExtensionByNumber(type_id)
# pylint: disable=protected-access
if extension is not None:
value = field_dict.get(extension)
if value is None:
@ -707,8 +808,9 @@ def MessageSetItemDecoder(descriptor):
else:
if not message._unknown_fields:
message._unknown_fields = []
message._unknown_fields.append((MESSAGE_SET_ITEM_TAG,
buffer[message_set_item_start:pos]))
message._unknown_fields.append(
(MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
# pylint: enable=protected-access
return pos
@ -767,7 +869,7 @@ def _SkipVarint(buffer, pos, end):
# Previously ord(buffer[pos]) raised IndexError when pos is out of range.
# With this code, ord(b'') raises TypeError. Both are handled in
# python_message.py to generate a 'Truncated message' error.
while ord(buffer[pos:pos+1]) & 0x80:
while ord(buffer[pos:pos+1].tobytes()) & 0x80:
pos += 1
pos += 1
if pos > end:
@ -782,6 +884,13 @@ def _SkipFixed64(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
def _DecodeFixed64(buffer, pos):
"""Decode a fixed64."""
new_pos = pos + 8
return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
def _SkipLengthDelimited(buffer, pos, end):
"""Skip a length-delimited value. Returns the new position."""
@ -791,6 +900,7 @@ def _SkipLengthDelimited(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
def _SkipGroup(buffer, pos, end):
"""Skip sub-group. Returns the new position."""
@ -801,11 +911,53 @@ def _SkipGroup(buffer, pos, end):
return pos
pos = new_pos
def _DecodeGroup(buffer, pos):
"""Decode group. Returns the UnknownFieldSet and new position."""
unknown_field_set = containers.UnknownFieldSet()
while 1:
(tag_bytes, pos) = ReadTag(buffer, pos)
(tag, _) = _DecodeVarint(tag_bytes, 0)
field_number, wire_type = wire_format.UnpackTag(tag)
if wire_type == wire_format.WIRETYPE_END_GROUP:
break
(data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
# pylint: disable=protected-access
unknown_field_set._add(field_number, wire_type, data)
return (unknown_field_set, pos)
def _DecodeUnknownField(buffer, pos, wire_type):
"""Decode a unknown field. Returns the UnknownField and new position."""
if wire_type == wire_format.WIRETYPE_VARINT:
(data, pos) = _DecodeVarint(buffer, pos)
elif wire_type == wire_format.WIRETYPE_FIXED64:
(data, pos) = _DecodeFixed64(buffer, pos)
elif wire_type == wire_format.WIRETYPE_FIXED32:
(data, pos) = _DecodeFixed32(buffer, pos)
elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
(size, pos) = _DecodeVarint(buffer, pos)
data = buffer[pos:pos+size]
pos += size
elif wire_type == wire_format.WIRETYPE_START_GROUP:
(data, pos) = _DecodeGroup(buffer, pos)
elif wire_type == wire_format.WIRETYPE_END_GROUP:
return (0, -1)
else:
raise _DecodeError('Wrong wire type in tag.')
return (data, pos)
def _EndGroup(buffer, pos, end):
"""Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
return -1
def _SkipFixed32(buffer, pos, end):
"""Skip a fixed32 value. Returns the new position."""
@ -814,6 +966,14 @@ def _SkipFixed32(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
def _DecodeFixed32(buffer, pos):
"""Decode a fixed32."""
new_pos = pos + 4
return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
def _RaiseInvalidWireType(buffer, pos, end):
"""Skip function for unknown wire types. Raises an exception."""

View File

@ -43,6 +43,7 @@ import warnings
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf.internal import no_package_pb2
from google.protobuf import descriptor_database
@ -52,7 +53,10 @@ class DescriptorDatabaseTest(unittest.TestCase):
db = descriptor_database.DescriptorDatabase()
file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
factory_test2_pb2.DESCRIPTOR.serialized_pb)
file_desc_proto2 = descriptor_pb2.FileDescriptorProto.FromString(
no_package_pb2.DESCRIPTOR.serialized_pb)
db.Add(file_desc_proto)
db.Add(file_desc_proto2)
self.assertEqual(file_desc_proto, db.FindFileByName(
'google/protobuf/internal/factory_test2.proto'))
@ -76,6 +80,10 @@ class DescriptorDatabaseTest(unittest.TestCase):
# Can find enum value.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Enum.FACTORY_2_VALUE_0'))
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.FACTORY_2_VALUE_0'))
self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol(
'.NO_PACKAGE_VALUE_0'))
# Can find top level extension.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.another_field'))

View File

@ -36,7 +36,6 @@ __author__ = 'matthewtoia@google.com (Matt Toia)'
import copy
import os
import sys
import warnings
try:
@ -55,6 +54,7 @@ from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf.internal import file_options_test_pb2
from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import no_package_pb2
from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool
@ -120,7 +120,6 @@ class DescriptorPoolTestBase(object):
self.assertIsInstance(file_desc5, descriptor.FileDescriptor)
self.assertEqual('google/protobuf/unittest.proto',
file_desc5.name)
# Tests the generated pool.
assert descriptor_pool.Default().FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.one_more_field')
@ -129,6 +128,32 @@ class DescriptorPoolTestBase(object):
assert descriptor_pool.Default().FindFileContainingSymbol(
'protobuf_unittest.TestService')
# Can find field.
file_desc6 = self.pool.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory1Message.list_value')
self.assertIsInstance(file_desc6, descriptor.FileDescriptor)
self.assertEqual('google/protobuf/internal/factory_test1.proto',
file_desc6.name)
# Can find top level Enum value.
file_desc7 = self.pool.FindFileContainingSymbol(
'google.protobuf.python.internal.FACTORY_1_VALUE_0')
self.assertIsInstance(file_desc7, descriptor.FileDescriptor)
self.assertEqual('google/protobuf/internal/factory_test1.proto',
file_desc7.name)
# Can find nested Enum value.
file_desc8 = self.pool.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes.FOO')
self.assertIsInstance(file_desc8, descriptor.FileDescriptor)
self.assertEqual('google/protobuf/unittest.proto',
file_desc8.name)
# TODO(jieluo): Add tests for no package when b/13860351 is fixed.
self.assertRaises(KeyError, self.pool.FindFileContainingSymbol,
'google.protobuf.python.internal.Factory1Message.none_field')
def testFindFileContainingSymbolFailure(self):
with self.assertRaises(KeyError):
self.pool.FindFileContainingSymbol('Does not exist')
@ -217,11 +242,10 @@ class DescriptorPoolTestBase(object):
def testFindTypeErrors(self):
self.assertRaises(TypeError, self.pool.FindExtensionByNumber, '')
self.assertRaises(KeyError, self.pool.FindMethodByName, '')
# TODO(jieluo): Fix python to raise correct errors.
if api_implementation.Type() == 'cpp':
self.assertRaises(TypeError, self.pool.FindMethodByName, 0)
self.assertRaises(KeyError, self.pool.FindMethodByName, '')
error_type = TypeError
else:
error_type = AttributeError
@ -231,6 +255,7 @@ class DescriptorPoolTestBase(object):
self.assertRaises(error_type, self.pool.FindEnumTypeByName, 0)
self.assertRaises(error_type, self.pool.FindOneofByName, 0)
self.assertRaises(error_type, self.pool.FindServiceByName, 0)
self.assertRaises(error_type, self.pool.FindMethodByName, 0)
self.assertRaises(error_type, self.pool.FindFileContainingSymbol, 0)
if api_implementation.Type() == 'python':
error_type = KeyError
@ -275,11 +300,6 @@ class DescriptorPoolTestBase(object):
self.pool.FindEnumTypeByName('Does not exist')
def testFindFieldByName(self):
if isinstance(self, SecondaryDescriptorFromDescriptorDB):
if api_implementation.Type() == 'cpp':
# TODO(jieluo): Fix cpp extension to find field correctly
# when descriptor pool is using an underlying database.
return
field = self.pool.FindFieldByName(
'google.protobuf.python.internal.Factory1Message.list_value')
self.assertEqual(field.name, 'list_value')
@ -290,11 +310,6 @@ class DescriptorPoolTestBase(object):
self.pool.FindFieldByName('Does not exist')
def testFindOneofByName(self):
if isinstance(self, SecondaryDescriptorFromDescriptorDB):
if api_implementation.Type() == 'cpp':
# TODO(jieluo): Fix cpp extension to find oneof correctly
# when descriptor pool is using an underlying database.
return
oneof = self.pool.FindOneofByName(
'google.protobuf.python.internal.Factory2Message.oneof_field')
self.assertEqual(oneof.name, 'oneof_field')
@ -302,11 +317,6 @@ class DescriptorPoolTestBase(object):
self.pool.FindOneofByName('Does not exist')
def testFindExtensionByName(self):
if isinstance(self, SecondaryDescriptorFromDescriptorDB):
if api_implementation.Type() == 'cpp':
# TODO(jieluo): Fix cpp extension to find extension correctly
# when descriptor pool is using an underlying database.
return
# An extension defined in a message.
extension = self.pool.FindExtensionByName(
'google.protobuf.python.internal.Factory2Message.one_more_field')
@ -382,6 +392,11 @@ class DescriptorPoolTestBase(object):
with self.assertRaises(KeyError):
self.pool.FindServiceByName('Does not exist')
method = self.pool.FindMethodByName('protobuf_unittest.TestService.Foo')
self.assertIs(method.containing_service, service)
with self.assertRaises(KeyError):
self.pool.FindMethodByName('protobuf_unittest.TestService.Doesnotexist')
def testUserDefinedDB(self):
db = descriptor_database.DescriptorDatabase()
self.pool = descriptor_pool.DescriptorPool(db)
@ -601,6 +616,8 @@ class CreateDescriptorPoolTest(DescriptorPoolTestBase, unittest.TestCase):
unittest_import_pb2.DESCRIPTOR.serialized_pb))
self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
unittest_pb2.DESCRIPTOR.serialized_pb))
self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
no_package_pb2.DESCRIPTOR.serialized_pb))
class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase,
@ -620,6 +637,8 @@ class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase,
unittest_import_pb2.DESCRIPTOR.serialized_pb))
db.Add(descriptor_pb2.FileDescriptorProto.FromString(
unittest_pb2.DESCRIPTOR.serialized_pb))
db.Add(descriptor_pb2.FileDescriptorProto.FromString(
no_package_pb2.DESCRIPTOR.serialized_pb))
self.pool = descriptor_pool.DescriptorPool(descriptor_db=db)
@ -746,11 +765,7 @@ class MessageField(object):
test.assertEqual(msg_desc, field_desc.containing_type)
test.assertEqual(field_type_desc, field_desc.message_type)
test.assertEqual(file_desc, field_desc.file)
# TODO(jieluo): Fix python and cpp extension diff for message field
# default value.
if api_implementation.Type() == 'cpp':
test.assertRaises(
NotImplementedError, getattr, field_desc, 'default_value')
test.assertEqual(field_desc.default_value, None)
class StringField(object):

View File

@ -452,6 +452,17 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual('attribute is not writable: has_options',
str(e.exception))
def testDefault(self):
message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
field = message_descriptor.fields_by_name['repeated_int32']
self.assertEqual(field.default_value, [])
field = message_descriptor.fields_by_name['repeated_nested_message']
self.assertEqual(field.default_value, [])
field = message_descriptor.fields_by_name['optionalgroup']
self.assertEqual(field.default_value, None)
field = message_descriptor.fields_by_name['optional_nested_message']
self.assertEqual(field.default_value, None)
class NewDescriptorTest(DescriptorTest):
"""Redo the same tests as above, but with a separate DescriptorPool."""

View File

@ -56,3 +56,17 @@ message Factory1Message {
extensions 1000 to max;
}
message Factory1MethodRequest {
optional string argument = 1;
}
message Factory1MethodResponse {
optional string result = 1;
}
service Factory1Service {
// Dummy method for this dummy service.
rpc Factory1Method(Factory1MethodRequest) returns (Factory1MethodResponse) {
}
}

View File

@ -142,10 +142,8 @@ class MessageFactoryTest(unittest.TestCase):
self.assertEqual('test2', msg1.Extensions[ext2])
self.assertEqual(None,
msg1.Extensions._FindExtensionByNumber(12321))
self.assertRaises(TypeError, len, msg1.Extensions)
if api_implementation.Type() == 'cpp':
# TODO(jieluo): Fix len to return the correct value.
# self.assertEqual(2, len(msg1.Extensions))
self.assertEqual(len(msg1.Extensions), len(msg1.Extensions))
self.assertRaises(TypeError,
msg1.Extensions._FindExtensionByName, 0)
self.assertRaises(TypeError,

View File

@ -1,4 +1,5 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@ -49,6 +50,7 @@ import copy
import math
import operator
import pickle
import pydoc
import six
import sys
import warnings
@ -72,6 +74,7 @@ from google.protobuf import message_factory
from google.protobuf import text_format
from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import packed_field_test_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
@ -415,6 +418,37 @@ class MessageTest(BaseTestCase):
empty.ParseFromString(populated.SerializeToString())
self.assertEqual(str(empty), '')
def testMergeFromRepeatedField(self, message_module):
msg = message_module.TestAllTypes()
msg.repeated_int32.append(1)
msg.repeated_int32.append(3)
msg.repeated_nested_message.add(bb=1)
msg.repeated_nested_message.add(bb=2)
other_msg = message_module.TestAllTypes()
other_msg.repeated_nested_message.add(bb=3)
other_msg.repeated_nested_message.add(bb=4)
other_msg.repeated_int32.append(5)
other_msg.repeated_int32.append(7)
msg.repeated_int32.MergeFrom(other_msg.repeated_int32)
self.assertEqual(4, len(msg.repeated_int32))
msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message)
self.assertEqual([1, 2, 3, 4],
[m.bb for m in msg.repeated_nested_message])
def testAddWrongRepeatedNestedField(self, message_module):
msg = message_module.TestAllTypes()
try:
msg.repeated_nested_message.add('wrong')
except TypeError:
pass
try:
msg.repeated_nested_message.add(value_field='wrong')
except ValueError:
pass
self.assertEqual(len(msg.repeated_nested_message), 0)
def testRepeatedNestedFieldIteration(self, message_module):
msg = message_module.TestAllTypes()
msg.repeated_nested_message.add(bb=1)
@ -645,6 +679,82 @@ class MessageTest(BaseTestCase):
m.payload.repeated_int32.extend([])
self.assertTrue(m.HasField('payload'))
def testMergeFrom(self, message_module):
m1 = message_module.TestAllTypes()
m2 = message_module.TestAllTypes()
# Cpp extension will lazily create a sub message which is immutable.
self.assertEqual(0, m1.optional_nested_message.bb)
m2.optional_nested_message.bb = 1
# Make sure cmessage pointing to a mutable message after merge instead of
# the lazily created message.
m1.MergeFrom(m2)
self.assertEqual(1, m1.optional_nested_message.bb)
# Test more nested sub message.
msg1 = message_module.NestedTestAllTypes()
msg2 = message_module.NestedTestAllTypes()
self.assertEqual(0, msg1.child.payload.optional_nested_message.bb)
msg2.child.payload.optional_nested_message.bb = 1
msg1.MergeFrom(msg2)
self.assertEqual(1, msg1.child.payload.optional_nested_message.bb)
# Test repeated field.
self.assertEqual(msg1.payload.repeated_nested_message,
msg1.payload.repeated_nested_message)
msg2.payload.repeated_nested_message.add().bb = 1
msg1.MergeFrom(msg2)
self.assertEqual(1, len(msg1.payload.repeated_nested_message))
self.assertEqual(1, msg1.payload.repeated_nested_message[0].bb)
def testMergeFromString(self, message_module):
m1 = message_module.TestAllTypes()
m2 = message_module.TestAllTypes()
# Cpp extension will lazily create a sub message which is immutable.
self.assertEqual(0, m1.optional_nested_message.bb)
m2.optional_nested_message.bb = 1
# Make sure cmessage pointing to a mutable message after merge instead of
# the lazily created message.
m1.MergeFromString(m2.SerializeToString())
self.assertEqual(1, m1.optional_nested_message.bb)
@unittest.skipIf(six.PY2, 'memoryview objects are not supported on py2')
def testMergeFromStringUsingMemoryViewWorksInPy3(self, message_module):
m2 = message_module.TestAllTypes()
m2.optional_string = 'scalar string'
m2.repeated_string.append('repeated string')
m2.optional_bytes = b'scalar bytes'
m2.repeated_bytes.append(b'repeated bytes')
serialized = m2.SerializeToString()
memview = memoryview(serialized)
m1 = message_module.TestAllTypes.FromString(memview)
self.assertEqual(m1.optional_bytes, b'scalar bytes')
self.assertEqual(m1.repeated_bytes, [b'repeated bytes'])
self.assertEqual(m1.optional_string, 'scalar string')
self.assertEqual(m1.repeated_string, ['repeated string'])
# Make sure that the memoryview was correctly converted to bytes, and
# that a sub-sliced memoryview is not being used.
self.assertIsInstance(m1.optional_bytes, bytes)
self.assertIsInstance(m1.repeated_bytes[0], bytes)
self.assertIsInstance(m1.optional_string, six.text_type)
self.assertIsInstance(m1.repeated_string[0], six.text_type)
@unittest.skipIf(six.PY3, 'memoryview is supported by py3')
def testMergeFromStringUsingMemoryViewIsPy2Error(self, message_module):
memview = memoryview(b'')
with self.assertRaises(TypeError):
message_module.TestAllTypes.FromString(memview)
def testMergeFromEmpty(self, message_module):
m1 = message_module.TestAllTypes()
# Cpp extension will lazily create a sub message which is immutable.
self.assertEqual(0, m1.optional_nested_message.bb)
self.assertFalse(m1.HasField('optional_nested_message'))
# Make sure the sub message is still immutable after merge from empty.
m1.MergeFromString(b'') # field state should not change
self.assertFalse(m1.HasField('optional_nested_message'))
def ensureNestedMessageExists(self, msg, attribute):
"""Make sure that a nested message object exists.
@ -1067,14 +1177,8 @@ class MessageTest(BaseTestCase):
with self.assertRaises(AttributeError):
m.repeated_int32 = []
m.repeated_int32.append(1)
if api_implementation.Type() == 'cpp':
# For test coverage: cpp has a different path if composite
# field is in cache
with self.assertRaises(TypeError):
m.repeated_int32 = []
else:
with self.assertRaises(AttributeError):
m.repeated_int32 = []
with self.assertRaises(AttributeError):
m.repeated_int32 = []
# Class to test proto2-only features (required, extensions, etc.)
@ -1169,6 +1273,21 @@ class Proto2Test(BaseTestCase):
msg = unittest_pb2.TestAllTypes()
self.assertRaises(AttributeError, getattr, msg, 'Extensions')
def testMergeFromExtensions(self):
msg1 = more_extensions_pb2.TopLevelMessage()
msg2 = more_extensions_pb2.TopLevelMessage()
# Cpp extension will lazily create a sub message which is immutable.
self.assertEqual(0, msg1.submessage.Extensions[
more_extensions_pb2.optional_int_extension])
self.assertFalse(msg1.HasField('submessage'))
msg2.submessage.Extensions[
more_extensions_pb2.optional_int_extension] = 123
# Make sure cmessage and extensions pointing to a mutable message
# after merge instead of the lazily created message.
msg1.MergeFrom(msg2)
self.assertEqual(123, msg1.submessage.Extensions[
more_extensions_pb2.optional_int_extension])
def testGoldenExtensions(self):
golden_data = test_util.GoldenFileData('golden_message')
golden_message = unittest_pb2.TestAllExtensions()
@ -1316,6 +1435,15 @@ class Proto2Test(BaseTestCase):
unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
def test_documentation(self):
# Also used by the interactive help() function.
doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message')
self.assertIn('class TestAllTypes', doc)
self.assertIn('SerializePartialToString', doc)
self.assertIn('repeated_float', doc)
base = unittest_pb2.TestAllTypes.__bases__[0]
self.assertRaises(AttributeError, getattr, base, '_extensions_by_name')
# Class to test proto3-only features/behavior (updated field presence & enums)
class Proto3Test(BaseTestCase):
@ -1539,10 +1667,8 @@ class Proto3Test(BaseTestCase):
self.assertEqual(True, msg2.map_bool_bool[True])
self.assertEqual(2, msg2.map_int32_enum[888])
self.assertEqual(456, msg2.map_int32_enum[123])
# TODO(jieluo): Add cpp extension support.
if api_implementation.Type() == 'python':
self.assertEqual('{-123: -456}',
str(msg2.map_int32_int32))
self.assertEqual('{-123: -456}',
str(msg2.map_int32_int32))
def testMapEntryAlwaysSerialized(self):
msg = map_unittest_pb2.TestMap()
@ -1603,11 +1729,10 @@ class Proto3Test(BaseTestCase):
self.assertIn(123, msg2.map_int32_foreign_message)
self.assertIn(-456, msg2.map_int32_foreign_message)
self.assertEqual(2, len(msg2.map_int32_foreign_message))
msg2.map_int32_foreign_message[123].c = 1
# TODO(jieluo): Fix text format for message map.
# TODO(jieluo): Add cpp extension support.
if api_implementation.Type() == 'python':
self.assertEqual(15,
len(str(msg2.map_int32_foreign_message)))
self.assertIn(str(msg2.map_int32_foreign_message),
('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }'))
def testNestedMessageMapItemDelete(self):
msg = map_unittest_pb2.TestMap()
@ -1721,6 +1846,15 @@ class Proto3Test(BaseTestCase):
self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d'))
# Test when cpp extension cache a map.
m1 = map_unittest_pb2.TestMap()
m2 = map_unittest_pb2.TestMap()
self.assertEqual(m1.map_int32_foreign_message,
m1.map_int32_foreign_message)
m2.map_int32_foreign_message[123].c = 10
m1.MergeFrom(m2)
self.assertEqual(10, m2.map_int32_foreign_message[123].c)
def testMergeFromBadType(self):
msg = map_unittest_pb2.TestMap()
with self.assertRaisesRegexp(
@ -1972,7 +2106,7 @@ class Proto3Test(BaseTestCase):
def testMapValidAfterFieldCleared(self):
# Map needs to work even if field is cleared.
# For the C++ implementation this tests the correctness of
# ScalarMapContainer::Release()
# MapContainer::Release()
msg = map_unittest_pb2.TestMap()
int32_map = msg.map_int32_int32
@ -1988,7 +2122,7 @@ class Proto3Test(BaseTestCase):
def testMessageMapValidAfterFieldCleared(self):
# Map needs to work even if field is cleared.
# For the C++ implementation this tests the correctness of
# ScalarMapContainer::Release()
# MapContainer::Release()
msg = map_unittest_pb2.TestMap()
int32_foreign_message = msg.map_int32_foreign_message
@ -1998,6 +2132,24 @@ class Proto3Test(BaseTestCase):
self.assertEqual(b'', msg.SerializeToString())
self.assertTrue(2 in int32_foreign_message.keys())
def testMessageMapItemValidAfterTopMessageCleared(self):
# Message map item needs to work even if it is cleared.
# For the C++ implementation this tests the correctness of
# MapContainer::Release()
msg = map_unittest_pb2.TestMap()
msg.map_int32_all_types[2].optional_string = 'bar'
if api_implementation.Type() == 'cpp':
# Need to keep the map reference because of b/27942626.
# TODO(jieluo): Remove it.
unused_map = msg.map_int32_all_types # pylint: disable=unused-variable
msg_value = msg.map_int32_all_types[2]
msg.Clear()
# Reset to trigger sync between repeated field and map in c++.
msg.map_int32_all_types[3].optional_string = 'foo'
self.assertEqual(msg_value.optional_string, 'bar')
def testMapIterInvalidatedByClearField(self):
# Map iterator is invalidated when field is cleared.
# But this case does need to not crash the interpreter.
@ -2058,6 +2210,80 @@ class Proto3Test(BaseTestCase):
msg.map_string_foreign_message['foo'].c = 5
self.assertEqual(0, len(msg.FindInitializationErrors()))
def testStrictUtf8Check(self):
# Test u'\ud801' is rejected at parser in both python2 and python3.
serialized = (b'r\x03\xed\xa0\x81')
msg = unittest_proto3_arena_pb2.TestAllTypes()
with self.assertRaises(Exception) as context:
msg.MergeFromString(serialized)
if api_implementation.Type() == 'python':
self.assertIn('optional_string', str(context.exception))
else:
self.assertIn('Error parsing message', str(context.exception))
# Test optional_string=u'😍' is accepted.
serialized = unittest_proto3_arena_pb2.TestAllTypes(
optional_string=u'😍').SerializeToString()
msg2 = unittest_proto3_arena_pb2.TestAllTypes()
msg2.MergeFromString(serialized)
self.assertEqual(msg2.optional_string, u'😍')
msg = unittest_proto3_arena_pb2.TestAllTypes(
optional_string=u'\ud001')
self.assertEqual(msg.optional_string, u'\ud001')
@unittest.skipIf(six.PY2, 'Surrogates are acceptable in python2')
def testSurrogatesInPython3(self):
# Surrogates like U+D83D is an invalid unicode character, it is
# supported by Python2 only because in some builds, unicode strings
# use 2-bytes code units. Since Python 3.3, we don't have this problem.
#
# Surrogates are utf16 code units, in a unicode string they are invalid
# characters even when they appear in pairs like u'\ud801\udc01'. Protobuf
# Python3 reject such cases at setters and parsers. Python2 accpect it
# to keep same features with the language itself. 'Unpaired pairs'
# like u'\ud801' are rejected at parsers when strict utf8 check is enabled
# in proto3 to keep same behavior with c extension.
# Surrogates are rejected at setters in Python3.
with self.assertRaises(ValueError):
unittest_proto3_arena_pb2.TestAllTypes(
optional_string=u'\ud801\udc01')
with self.assertRaises(ValueError):
unittest_proto3_arena_pb2.TestAllTypes(
optional_string=b'\xed\xa0\x81')
with self.assertRaises(ValueError):
unittest_proto3_arena_pb2.TestAllTypes(
optional_string=u'\ud801')
with self.assertRaises(ValueError):
unittest_proto3_arena_pb2.TestAllTypes(
optional_string=u'\ud801\ud801')
@unittest.skipIf(six.PY3, 'Surrogates are rejected at setters in Python3')
def testSurrogatesInPython2(self):
# Test optional_string=u'\ud801\udc01'.
# surrogate pair is acceptable in python2.
msg = unittest_proto3_arena_pb2.TestAllTypes(
optional_string=u'\ud801\udc01')
# TODO(jieluo): Change pure python to have same behavior with c extension.
# Some build in python2 consider u'\ud801\udc01' and u'\U00010401' are
# equal, some are not equal.
if api_implementation.Type() == 'python':
self.assertEqual(msg.optional_string, u'\ud801\udc01')
else:
self.assertEqual(msg.optional_string, u'\U00010401')
serialized = msg.SerializeToString()
msg2 = unittest_proto3_arena_pb2.TestAllTypes()
msg2.MergeFromString(serialized)
self.assertEqual(msg2.optional_string, u'\U00010401')
# Python2 does not reject surrogates at setters.
msg = unittest_proto3_arena_pb2.TestAllTypes(
optional_string=b'\xed\xa0\x81')
unittest_proto3_arena_pb2.TestAllTypes(
optional_string=u'\ud801')
unittest_proto3_arena_pb2.TestAllTypes(
optional_string=u'\ud801\ud801')
class ValidTypeNamesTest(BaseTestCase):

View File

@ -1,3 +1,33 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
syntax = "proto2";
enum NoPackageEnum {

View File

@ -56,6 +56,7 @@ import sys
import weakref
import six
from six.moves import range
# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import api_implementation
@ -124,6 +125,21 @@ class GeneratedProtocolMessageType(type):
Newly-allocated class.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
# If a concrete class already exists for this descriptor, don't try to
# create another. Doing so will break any messages that already exist with
# the existing class.
#
# The C++ implementation appears to have its own internal `PyMessageFactory`
# to achieve similar results.
#
# This most commonly happens in `text_format.py` when using descriptors from
# a custom pool; it calls symbol_database.Global().getPrototype() on a
# descriptor which already has an existing concrete class.
new_class = getattr(descriptor, '_concrete_class', None)
if new_class:
return new_class
if descriptor.full_name in well_known_types.WKTBASES:
bases += (well_known_types.WKTBASES[descriptor.full_name],)
_AddClassAttributesForNestedExtensions(descriptor, dictionary)
@ -151,6 +167,16 @@ class GeneratedProtocolMessageType(type):
type.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
# If this is an _existing_ class looked up via `_concrete_class` in the
# __new__ method above, then we don't need to re-initialize anything.
existing_class = getattr(descriptor, '_concrete_class', None)
if existing_class:
assert existing_class is cls, (
'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
% (descriptor.full_name))
return
cls._decoders_by_tag = {}
if (descriptor.has_options and
descriptor.GetOptions().message_set_wire_format):
@ -245,6 +271,7 @@ def _AddSlots(message_descriptor, dictionary):
'_cached_byte_size_dirty',
'_fields',
'_unknown_fields',
'_unknown_field_set',
'_is_present_in_parent',
'_listener',
'_listener_for_children',
@ -271,6 +298,13 @@ def _IsMessageMapField(field):
return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
def _IsStrictUtf8Check(field):
if field.containing_type.syntax != 'proto3':
return False
enforce_utf8 = True
return enforce_utf8
def _AttachFieldHelpers(cls, field_descriptor):
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
is_packable = (is_repeated and
@ -322,10 +356,16 @@ def _AttachFieldHelpers(cls, field_descriptor):
field_decoder = decoder.MapDecoder(
field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
is_message_map)
elif decode_type == _FieldDescriptor.TYPE_STRING:
is_strict_utf8_check = _IsStrictUtf8Check(field_descriptor)
field_decoder = decoder.StringDecoder(
field_descriptor.number, is_repeated, is_packed,
field_descriptor, field_descriptor._default_constructor,
is_strict_utf8_check)
else:
field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
field_descriptor.number, is_repeated, is_packed,
field_descriptor, field_descriptor._default_constructor)
field_descriptor.number, is_repeated, is_packed,
field_descriptor, field_descriptor._default_constructor)
cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
@ -422,6 +462,9 @@ def _DefaultValueConstructorForField(field):
# _concrete_class may not yet be initialized.
message_type = field.message_type
def MakeSubMessageDefault(message):
assert getattr(message_type, '_concrete_class', None), (
'Uninitialized concrete class found for field %r (message type %r)'
% (field.full_name, message_type.full_name))
result = message_type._concrete_class()
result._SetListener(
_OneofListener(message, field)
@ -477,6 +520,9 @@ def _AddInitMethod(message_descriptor, cls):
# _unknown_fields is () when empty for efficiency, and will be turned into
# a list if fields are added.
self._unknown_fields = ()
# _unknown_field_set is None when empty for efficiency, and will be
# turned into UnknownFieldSet struct if fields are added.
self._unknown_field_set = None # pylint: disable=protected-access
self._is_present_in_parent = False
self._listener = message_listener_mod.NullMessageListener()
self._listener_for_children = _Listener(self)
@ -584,6 +630,14 @@ def _AddPropertiesForField(field, cls):
_AddPropertiesForNonRepeatedScalarField(field, cls)
class _FieldProperty(property):
__slots__ = ('DESCRIPTOR',)
def __init__(self, descriptor, getter, setter, doc):
property.__init__(self, getter, setter, doc=doc)
self.DESCRIPTOR = descriptor
def _AddPropertiesForRepeatedField(field, cls):
"""Adds a public property for a "repeated" protocol message field. Clients
can use this property to get the value of the field, which will be either a
@ -625,7 +679,7 @@ def _AddPropertiesForRepeatedField(field, cls):
'"%s" in protocol message object.' % proto_field_name)
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
setattr(cls, property_name, property(getter, setter, doc=doc))
setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedScalarField(field, cls):
@ -681,7 +735,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
# Add a property to encapsulate the getter/setter.
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
setattr(cls, property_name, property(getter, setter, doc=doc))
setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
@ -725,7 +779,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
# Add a property to encapsulate the getter.
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
setattr(cls, property_name, property(getter, setter, doc=doc))
setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForExtensions(descriptor, cls):
@ -949,13 +1003,8 @@ def _AddEqualsMethod(message_descriptor, cls):
if not self.ListFields() == other.ListFields():
return False
# Sort unknown fields because their order shouldn't affect equality test.
unknown_fields = list(self._unknown_fields)
unknown_fields.sort()
other_unknown_fields = list(other._unknown_fields)
other_unknown_fields.sort()
return unknown_fields == other_unknown_fields
# pylint: disable=protected-access
return self._unknown_field_set == other._unknown_field_set
cls.__eq__ = __eq__
@ -1078,6 +1127,13 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
def _AddMergeFromStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def MergeFromString(self, serialized):
if isinstance(serialized, memoryview) and six.PY2:
raise TypeError(
'memoryview not supported in Python 2 with the pure Python proto '
'implementation: this is to maintain compatibility with the C++ '
'implementation')
serialized = memoryview(serialized)
length = len(serialized)
try:
if self._InternalParse(serialized, 0, length) != length:
@ -1095,26 +1151,54 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
decoders_by_tag = cls._decoders_by_tag
is_proto3 = message_descriptor.syntax == "proto3"
def InternalParse(self, buffer, pos, end):
"""Create a message from serialized bytes.
Args:
self: Message, instance of the proto message object.
buffer: memoryview of the serialized data.
pos: int, position to start in the serialized data.
end: int, end position of the serialized data.
Returns:
Message object.
"""
# Guard against internal misuse, since this function is called internally
# quite extensively, and its easy to accidentally pass bytes.
assert isinstance(buffer, memoryview)
self._Modified()
field_dict = self._fields
unknown_field_list = self._unknown_fields
# pylint: disable=protected-access
unknown_field_set = self._unknown_field_set
while pos != end:
(tag_bytes, new_pos) = local_ReadTag(buffer, pos)
field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
if field_decoder is None:
value_start_pos = new_pos
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
if not self._unknown_fields: # pylint: disable=protected-access
self._unknown_fields = [] # pylint: disable=protected-access
if unknown_field_set is None:
# pylint: disable=protected-access
self._unknown_field_set = containers.UnknownFieldSet()
# pylint: disable=protected-access
unknown_field_set = self._unknown_field_set
# pylint: disable=protected-access
(tag, _) = decoder._DecodeVarint(tag_bytes, 0)
field_number, wire_type = wire_format.UnpackTag(tag)
# TODO(jieluo): remove old_pos.
old_pos = new_pos
(data, new_pos) = decoder._DecodeUnknownField(
buffer, new_pos, wire_type) # pylint: disable=protected-access
if new_pos == -1:
return pos
if (not is_proto3 or
api_implementation.GetPythonProto3PreserveUnknownsDefault()):
if not unknown_field_list:
unknown_field_list = self._unknown_fields = []
unknown_field_list.append(
(tag_bytes, buffer[value_start_pos:new_pos]))
# pylint: disable=protected-access
unknown_field_set._add(field_number, wire_type, data)
# TODO(jieluo): remove _unknown_fields.
new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
if new_pos == -1:
return pos
self._unknown_fields.append(
(tag_bytes, buffer[old_pos:new_pos].tobytes()))
pos = new_pos
else:
pos = field_decoder(buffer, new_pos, end, self, field_dict)
@ -1259,6 +1343,10 @@ def _AddMergeFromMethod(cls):
if not self._unknown_fields:
self._unknown_fields = []
self._unknown_fields.extend(msg._unknown_fields)
# pylint: disable=protected-access
if self._unknown_field_set is None:
self._unknown_field_set = containers.UnknownFieldSet()
self._unknown_field_set._extend(msg._unknown_field_set)
cls.MergeFrom = MergeFrom
@ -1291,12 +1379,25 @@ def _Clear(self):
# Clear fields.
self._fields = {}
self._unknown_fields = ()
# pylint: disable=protected-access
if self._unknown_field_set is not None:
self._unknown_field_set._clear()
self._unknown_field_set = None
self._oneofs = {}
self._Modified()
def _UnknownFields(self):
if self._unknown_field_set is None: # pylint: disable=protected-access
# pylint: disable=protected-access
self._unknown_field_set = containers.UnknownFieldSet()
return self._unknown_field_set # pylint: disable=protected-access
def _DiscardUnknownFields(self):
self._unknown_fields = []
self._unknown_field_set = None # pylint: disable=protected-access
for field, value in self.ListFields():
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
if field.label == _FieldDescriptor.LABEL_REPEATED:
@ -1335,6 +1436,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddReduceMethod(cls)
# Adds methods which do not depend on cls.
cls.Clear = _Clear
cls.UnknownFields = _UnknownFields
cls.DiscardUnknownFields = _DiscardUnknownFields
cls._SetListener = _SetListener
@ -1471,6 +1573,10 @@ class _ExtensionDict(object):
if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
result = extension_handle._default_constructor(self._extended_message)
elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
assert getattr(extension_handle.message_type, '_concrete_class', None), (
'Uninitialized concrete class found for field %r (message type %r)'
% (extension_handle.full_name,
extension_handle.message_type.full_name))
result = extension_handle.message_type._concrete_class()
try:
result._SetListener(self._extended_message._listener_for_children)

View File

@ -64,6 +64,10 @@ from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import decoder
if six.PY3:
long = int # pylint: disable=redefined-builtin,invalid-name
BaseTestCase = testing_refleaks.BaseTestCase
@ -647,10 +651,7 @@ class ReflectionTest(BaseTestCase):
TestGetAndDeserialize('optional_int32', 1, int)
TestGetAndDeserialize('optional_int32', 1 << 30, int)
TestGetAndDeserialize('optional_uint32', 1 << 30, int)
try:
integer_64 = long
except NameError: # Python3
integer_64 = int
integer_64 = long
if struct.calcsize('L') == 4:
# Python only has signed ints, so 32-bit python can't fit an uint32
# in an int.
@ -1103,6 +1104,7 @@ class ReflectionTest(BaseTestCase):
self.assertEqual(23, myproto_instance.foo_field)
self.assertTrue(myproto_instance.HasField('foo_field'))
@testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
def testDescriptorProtoSupport(self):
# Hand written descriptors/reflection are only supported by the pure-Python
# implementation of the API.
@ -1141,7 +1143,8 @@ class ReflectionTest(BaseTestCase):
self.assertTrue('price' in desc.fields_by_name)
self.assertTrue('owners' in desc.fields_by_name)
class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType,
message.Message)):
DESCRIPTOR = desc
prius = CarMessage()
@ -2435,7 +2438,7 @@ class SerializationTest(BaseTestCase):
first_proto = unittest_pb2.TestAllTypes()
test_util.SetAllFields(first_proto)
serialized = first_proto.SerializeToString()
serialized = memoryview(first_proto.SerializeToString())
for truncation_point in range(len(serialized) + 1):
try:
@ -2857,6 +2860,38 @@ class SerializationTest(BaseTestCase):
self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
51)
def testFieldProperties(self):
cls = unittest_pb2.TestAllTypes
self.assertIs(cls.optional_int32.DESCRIPTOR,
cls.DESCRIPTOR.fields_by_name['optional_int32'])
self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER,
cls.optional_int32.DESCRIPTOR.number)
self.assertIs(cls.optional_nested_message.DESCRIPTOR,
cls.DESCRIPTOR.fields_by_name['optional_nested_message'])
self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER,
cls.optional_nested_message.DESCRIPTOR.number)
self.assertIs(cls.repeated_int32.DESCRIPTOR,
cls.DESCRIPTOR.fields_by_name['repeated_int32'])
self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER,
cls.repeated_int32.DESCRIPTOR.number)
def testFieldDataDescriptor(self):
msg = unittest_pb2.TestAllTypes()
msg.optional_int32 = 42
self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42)
unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25)
self.assertEqual(msg.optional_int32, 25)
with self.assertRaises(AttributeError):
del msg.optional_int32
try:
unittest_pb2.ForeignMessage.c.__get__(msg)
except TypeError:
pass # The cpp implementation cannot mix fields from other messages.
# This test exercises a specific check that avoids a crash.
else:
pass # The python implementation allows fields from other messages.
# This is useless, but works.
def testInitKwargs(self):
proto = unittest_pb2.TestAllTypes(
optional_int32=1,
@ -2963,6 +2998,7 @@ class ClassAPITest(BaseTestCase):
@unittest.skipIf(
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
'C++ implementation requires a call to MakeDescriptor()')
@testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable')
def testMakeClassWithNestedDescriptor(self):
leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '',
containing_type=None, fields=[],
@ -2980,10 +3016,7 @@ class ClassAPITest(BaseTestCase):
containing_type=None, fields=[],
nested_types=[child_desc, sibling_desc],
enum_types=[], extensions=[])
message_class = reflection.MakeClass(parent_desc)
self.assertIn('child', message_class.__dict__)
self.assertIn('sibling', message_class.__dict__)
self.assertIn('leaf', message_class.child.__dict__)
reflection.MakeClass(parent_desc)
def _GetSerializedFileDescriptor(self, name):
"""Get a serialized representation of a test FileDescriptorProto.

View File

@ -33,20 +33,19 @@
"""Test for google.protobuf.text_format."""
__author__ = 'kenton@google.com (Kenton Varda)'
import io
import math
import re
import six
import string
import textwrap
import six
# pylint: disable=g-import-not-at-top
try:
import unittest2 as unittest # PY26, pylint: disable=g-import-not-at-top
import unittest2 as unittest # PY26
except ImportError:
import unittest # pylint: disable=g-import-not-at-top
from google.protobuf.internal import _parameterized
import unittest
from google.protobuf import any_pb2
from google.protobuf import any_test_pb2
@ -54,12 +53,13 @@ from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import any_test_pb2 as test_extend_any
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import test_util
from google.protobuf import descriptor_pool
from google.protobuf import text_format
from google.protobuf.internal import _parameterized
# pylint: enable=g-import-not-at-top
# Low-level nuts-n-bolts tests.
@ -100,8 +100,8 @@ class TextFormatBase(unittest.TestCase):
return text
@_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2))
class TextFormatTest(TextFormatBase):
@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2)
class TextFormatMessageToStringTests(TextFormatBase):
def testPrintExotic(self, message_module):
message = message_module.TestAllTypes()
@ -154,6 +154,40 @@ class TextFormatTest(TextFormatBase):
'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 '
'repeated_string: "Google" repeated_string: "Zurich"')
def VerifyPrintShortFormatRepeatedFields(self, message_module, as_one_line):
message = message_module.TestAllTypes()
message.repeated_int32.append(1)
message.repeated_string.append('Google')
message.repeated_string.append('Hello,World')
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_FOO)
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR)
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ)
message.optional_nested_message.bb = 3
for i in (21, 32):
msg = message.repeated_nested_message.add()
msg.bb = i
expected_ascii = (
'optional_nested_message {\n bb: 3\n}\n'
'repeated_int32: [1]\n'
'repeated_string: "Google"\n'
'repeated_string: "Hello,World"\n'
'repeated_nested_message {\n bb: 21\n}\n'
'repeated_nested_message {\n bb: 32\n}\n'
'repeated_foreign_enum: [FOREIGN_FOO, FOREIGN_BAR, FOREIGN_BAZ]\n')
if as_one_line:
expected_ascii = expected_ascii.replace('\n ', '').replace('\n', '')
actual_ascii = text_format.MessageToString(
message, use_short_repeated_primitives=True,
as_one_line=as_one_line)
self.CompareToGoldenText(actual_ascii, expected_ascii)
parsed_message = message_module.TestAllTypes()
text_format.Parse(actual_ascii, parsed_message)
self.assertEqual(parsed_message, message)
def tesPrintShortFormatRepeatedFields(self, message_module, as_one_line):
self.VerifyPrintShortFormatRepeatedFields(message_module, False)
self.VerifyPrintShortFormatRepeatedFields(message_module, True)
def testPrintNestedNewLineInStringAsOneLine(self, message_module):
message = message_module.TestAllTypes()
message.optional_string = 'a\nnew\nline'
@ -213,13 +247,18 @@ class TextFormatTest(TextFormatBase):
def testPrintRawUtf8String(self, message_module):
message = message_module.TestAllTypes()
message.repeated_string.append(u'\u00fc\ua71f')
message.repeated_string.append(u'\u00fc\t\ua71f')
text = text_format.MessageToString(message, as_utf8=True)
self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n')
golden_unicode = u'repeated_string: "\u00fc\\t\ua71f"\n'
golden_text = golden_unicode if six.PY3 else golden_unicode.encode('utf-8')
# MessageToString always returns a native str.
self.CompareToGoldenText(text, golden_text)
parsed_message = message_module.TestAllTypes()
text_format.Parse(text, parsed_message)
self.assertEqual(message, parsed_message,
'\n%s != %s' % (message, parsed_message))
self.assertEqual(
message, parsed_message, '\n%s != %s (%s != %s)' %
(message, parsed_message, message.repeated_string[0],
parsed_message.repeated_string[0]))
def testPrintFloatFormat(self, message_module):
# Check that float_format argument is passed to sub-message formatting.
@ -259,6 +298,36 @@ class TextFormatTest(TextFormatBase):
message.c = 123
self.assertEqual('c: 123\n', str(message))
def testMessageToStringUnicode(self, message_module):
golden_unicode = u'Á short desçription and a 🍌.'
golden_bytes = golden_unicode.encode('utf-8')
message = message_module.TestAllTypes()
message.optional_string = golden_unicode
message.optional_bytes = golden_bytes
text = text_format.MessageToString(message, as_utf8=True)
golden_message = textwrap.dedent(
'optional_string: "Á short desçription and a 🍌."\n'
'optional_bytes: '
r'"\303\201 short des\303\247ription and a \360\237\215\214."'
'\n')
self.CompareToGoldenText(text, golden_message)
def testMessageToStringASCII(self, message_module):
golden_unicode = u'Á short desçription and a 🍌.'
golden_bytes = golden_unicode.encode('utf-8')
message = message_module.TestAllTypes()
message.optional_string = golden_unicode
message.optional_bytes = golden_bytes
text = text_format.MessageToString(message, as_utf8=False) # ASCII
golden_message = (
'optional_string: '
r'"\303\201 short des\303\247ription and a \360\237\215\214."'
'\n'
'optional_bytes: '
r'"\303\201 short des\303\247ription and a \360\237\215\214."'
'\n')
self.CompareToGoldenText(text, golden_message)
def testPrintField(self, message_module):
message = message_module.TestAllTypes()
field = message.DESCRIPTOR.fields_by_name['optional_float']
@ -289,6 +358,45 @@ class TextFormatTest(TextFormatBase):
self.assertEqual('0.0', out.getvalue())
out.close()
@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2)
class TextFormatMessageToTextBytesTests(TextFormatBase):
def testMessageToBytes(self, message_module):
message = message_module.ForeignMessage()
message.c = 123
self.assertEqual(b'c: 123\n', text_format.MessageToBytes(message))
def testRawUtf8RoundTrip(self, message_module):
message = message_module.TestAllTypes()
message.repeated_string.append(u'\u00fc\t\ua71f')
utf8_text = text_format.MessageToBytes(message, as_utf8=True)
golden_bytes = b'repeated_string: "\xc3\xbc\\t\xea\x9c\x9f"\n'
self.CompareToGoldenText(utf8_text, golden_bytes)
parsed_message = message_module.TestAllTypes()
text_format.Parse(utf8_text, parsed_message)
self.assertEqual(
message, parsed_message, '\n%s != %s (%s != %s)' %
(message, parsed_message, message.repeated_string[0],
parsed_message.repeated_string[0]))
def testEscapedUtf8ASCIIRoundTrip(self, message_module):
message = message_module.TestAllTypes()
message.repeated_string.append(u'\u00fc\t\ua71f')
ascii_text = text_format.MessageToBytes(message) # as_utf8=False default
golden_bytes = b'repeated_string: "\\303\\274\\t\\352\\234\\237"\n'
self.CompareToGoldenText(ascii_text, golden_bytes)
parsed_message = message_module.TestAllTypes()
text_format.Parse(ascii_text, parsed_message)
self.assertEqual(
message, parsed_message, '\n%s != %s (%s != %s)' %
(message, parsed_message, message.repeated_string[0],
parsed_message.repeated_string[0]))
@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2)
class TextFormatParserTests(TextFormatBase):
def testParseAllFields(self, message_module):
message = message_module.TestAllTypes()
test_util.SetAllFields(message)
@ -318,14 +426,14 @@ class TextFormatTest(TextFormatBase):
if message_module is unittest_pb2:
test_util.ExpectAllFieldsSet(self, message)
if six.PY2:
msg2 = message_module.TestAllTypes()
text = (u'optional_string: "café"')
text_format.Merge(text, msg2)
self.assertEqual(msg2.optional_string, u'café')
msg2.Clear()
text_format.Parse(text, msg2)
self.assertEqual(msg2.optional_string, u'café')
msg2 = message_module.TestAllTypes()
text = (u'optional_string: "café"')
text_format.Merge(text, msg2)
self.assertEqual(msg2.optional_string, u'café')
msg2.Clear()
self.assertEqual(msg2.optional_string, u'')
text_format.Parse(text, msg2)
self.assertEqual(msg2.optional_string, u'café')
def testParseExotic(self, message_module):
message = message_module.TestAllTypes()
@ -425,7 +533,8 @@ class TextFormatTest(TextFormatBase):
message = message_module.TestAllTypes()
text = 'optional_nested_enum: BARR'
six.assertRaisesRegex(self, text_format.ParseError,
(r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
(r'1:23 : \'optional_nested_enum: BARR\': '
r'Enum type "\w+.TestAllTypes.NestedEnum" '
r'has no value named BARR.'), text_format.Parse,
text, message)
@ -433,7 +542,8 @@ class TextFormatTest(TextFormatBase):
message = message_module.TestAllTypes()
text = 'optional_int32: bork'
six.assertRaisesRegex(self, text_format.ParseError,
('1:17 : Couldn\'t parse integer: bork'),
('1:17 : \'optional_int32: bork\': '
'Couldn\'t parse integer: bork'),
text_format.Parse, text, message)
def testParseStringFieldUnescape(self, message_module):
@ -457,6 +567,96 @@ class TextFormatTest(TextFormatBase):
message.repeated_string[4])
self.assertEqual(SLASH + 'x20', message.repeated_string[5])
def testParseOneof(self, message_module):
m = message_module.TestAllTypes()
m.oneof_uint32 = 11
m2 = message_module.TestAllTypes()
text_format.Parse(text_format.MessageToString(m), m2)
self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
def testParseMultipleOneof(self, message_module):
m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
m2 = message_module.TestAllTypes()
with six.assertRaisesRegex(self, text_format.ParseError,
' is specified along with field '):
text_format.Parse(m_string, m2)
# This example contains non-ASCII codepoint unicode data as literals
# which should come through as utf-8 for bytes, and as the unicode
# itself for string fields. It also demonstrates escaped binary data.
# The ur"" string prefix is unfortunately missing from Python 3
# so we resort to double escaping our \s so that they come through.
_UNICODE_SAMPLE = u"""
optional_bytes: 'Á short desçription'
optional_string: 'Á short desçription'
repeated_bytes: '\\303\\201 short des\\303\\247ription'
repeated_bytes: '\\x12\\x34\\x56\\x78\\x90\\xab\\xcd\\xef'
repeated_string: '\\xd0\\x9f\\xd1\\x80\\xd0\\xb8\\xd0\\xb2\\xd0\\xb5\\xd1\\x82'
"""
_BYTES_SAMPLE = _UNICODE_SAMPLE.encode('utf-8')
_GOLDEN_UNICODE = u'Á short desçription'
_GOLDEN_BYTES = _GOLDEN_UNICODE.encode('utf-8')
_GOLDEN_BYTES_1 = b'\x12\x34\x56\x78\x90\xab\xcd\xef'
_GOLDEN_STR_0 = u'Привет'
def testParseUnicode(self, message_module):
m = message_module.TestAllTypes()
text_format.Parse(self._UNICODE_SAMPLE, m)
self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
# repeated_bytes[1] contained simple \ escaped non-UTF-8 raw binary data.
self.assertEqual(m.repeated_bytes[1], self._GOLDEN_BYTES_1)
# repeated_string[0] contained \ escaped data representing the UTF-8
# representation of _GOLDEN_STR_0 - it needs to decode as such.
self.assertEqual(m.repeated_string[0], self._GOLDEN_STR_0)
def testParseBytes(self, message_module):
m = message_module.TestAllTypes()
text_format.Parse(self._BYTES_SAMPLE, m)
self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
# repeated_bytes[1] contained simple \ escaped non-UTF-8 raw binary data.
self.assertEqual(m.repeated_bytes[1], self._GOLDEN_BYTES_1)
# repeated_string[0] contained \ escaped data representing the UTF-8
# representation of _GOLDEN_STR_0 - it needs to decode as such.
self.assertEqual(m.repeated_string[0], self._GOLDEN_STR_0)
def testFromBytesFile(self, message_module):
m = message_module.TestAllTypes()
f = io.BytesIO(self._BYTES_SAMPLE)
text_format.ParseLines(f, m)
self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
def testFromUnicodeFile(self, message_module):
m = message_module.TestAllTypes()
f = io.StringIO(self._UNICODE_SAMPLE)
text_format.ParseLines(f, m)
self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
def testFromBytesLines(self, message_module):
m = message_module.TestAllTypes()
text_format.ParseLines(self._BYTES_SAMPLE.split(b'\n'), m)
self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
def testFromUnicodeLines(self, message_module):
m = message_module.TestAllTypes()
text_format.ParseLines(self._UNICODE_SAMPLE.split(u'\n'), m)
self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2)
class TextFormatMergeTests(TextFormatBase):
def testMergeDuplicateScalars(self, message_module):
message = message_module.TestAllTypes()
text = ('optional_int32: 42 ' 'optional_int32: 67')
@ -472,26 +672,12 @@ class TextFormatTest(TextFormatBase):
self.assertTrue(r is message)
self.assertEqual(2, message.optional_nested_message.bb)
def testParseOneof(self, message_module):
m = message_module.TestAllTypes()
m.oneof_uint32 = 11
m2 = message_module.TestAllTypes()
text_format.Parse(text_format.MessageToString(m), m2)
self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
def testMergeMultipleOneof(self, message_module):
m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
m2 = message_module.TestAllTypes()
text_format.Merge(m_string, m2)
self.assertEqual('oneof_string', m2.WhichOneof('oneof_field'))
def testParseMultipleOneof(self, message_module):
m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
m2 = message_module.TestAllTypes()
with self.assertRaisesRegexp(text_format.ParseError,
' is specified along with field '):
text_format.Parse(m_string, m2)
# These are tests that aren't fundamentally specific to proto2, but are at
# the moment because of differences between the proto2 and proto3 test schemas.
@ -938,7 +1124,7 @@ class Proto2Tests(TextFormatBase):
'}\n')
six.assertRaisesRegex(self,
text_format.ParseError,
'5:1 : Expected ">".',
'5:1 : \'}\': Expected ">".',
text_format.Parse,
malformed,
message,
@ -981,7 +1167,8 @@ class Proto2Tests(TextFormatBase):
with self.assertRaises(text_format.ParseError) as e:
text_format.Parse(text, message)
self.assertEqual(str(e.exception),
'1:27 : Expected identifier or number, got "bb".')
'1:27 : \'optional_nested_message { "bb": 1 }\': '
'Expected identifier or number, got "bb".')
def testParseBadExtension(self):
message = unittest_pb2.TestAllExtensions()
@ -998,7 +1185,8 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllTypes()
text = 'optional_nested_enum: 100'
six.assertRaisesRegex(self, text_format.ParseError,
(r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
(r'1:23 : \'optional_nested_enum: 100\': '
r'Enum type "\w+.TestAllTypes.NestedEnum" '
r'has no value with number 100.'), text_format.Parse,
text, message)
@ -1448,6 +1636,26 @@ class TokenizerTest(unittest.TestCase):
self.assertEqual(0, text_format._ConsumeUint64(tokenizer))
self.assertTrue(tokenizer.AtEnd())
def testConsumeOctalIntegers(self):
"""Test support for C style octal integers."""
text = '00 -00 04 0755 -010 007 -0033 08 -09 01'
tokenizer = text_format.Tokenizer(text.splitlines())
self.assertEqual(0, tokenizer.ConsumeInteger())
self.assertEqual(0, tokenizer.ConsumeInteger())
self.assertEqual(4, tokenizer.ConsumeInteger())
self.assertEqual(0o755, tokenizer.ConsumeInteger())
self.assertEqual(-0o10, tokenizer.ConsumeInteger())
self.assertEqual(7, tokenizer.ConsumeInteger())
self.assertEqual(-0o033, tokenizer.ConsumeInteger())
with self.assertRaises(text_format.ParseError):
tokenizer.ConsumeInteger() # 08
tokenizer.NextToken()
with self.assertRaises(text_format.ParseError):
tokenizer.ConsumeInteger() # -09
tokenizer.NextToken()
self.assertEqual(1, tokenizer.ConsumeInteger())
self.assertTrue(tokenizer.AtEnd())
def testConsumeByteString(self):
text = '"string1\''
tokenizer = text_format.Tokenizer(text.splitlines())
@ -1556,6 +1764,12 @@ class TokenizerTest(unittest.TestCase):
tokenizer.ConsumeCommentOrTrailingComment())
self.assertTrue(tokenizer.AtEnd())
def testHugeString(self):
# With pathologic backtracking, fails with Forge OOM.
text = '"' + 'a' * (10 * 1024 * 1024) + '"'
tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False)
tokenizer.ConsumeString()
# Tests for pretty printer functionality.
@_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2))

View File

@ -185,6 +185,14 @@ class UnicodeValueChecker(object):
'encoding. Non-UTF-8 strings must be converted to '
'unicode objects before being added.' %
(proposed_value))
else:
try:
proposed_value.encode('utf8')
except UnicodeEncodeError:
raise ValueError('%.1024r isn\'t a valid unicode string and '
'can\'t be encoded in UTF-8.'%
(proposed_value))
return proposed_value
def DefaultValue(self):

View File

@ -49,20 +49,12 @@ from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import type_checkers
from google.protobuf import descriptor
BaseTestCase = testing_refleaks.BaseTestCase
# CheckUnknownField() cannot be used by the C++ implementation because
# some protect members are called. It is not a behavior difference
# for python and C++ implementation.
def SkipCheckUnknownFieldIfCppImplementation(func):
return unittest.skipIf(
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
'Addtional test for pure python involved protect members')(func)
class UnknownFieldsTest(BaseTestCase):
def setUp(self):
@ -80,23 +72,11 @@ class UnknownFieldsTest(BaseTestCase):
# stdout.
self.assertTrue(data == self.all_fields_data)
def expectSerializeProto3(self, preserve):
def testSerializeProto3(self):
# Verify proto3 unknown fields behavior.
message = unittest_proto3_arena_pb2.TestEmptyMessage()
message.ParseFromString(self.all_fields_data)
if preserve:
self.assertEqual(self.all_fields_data, message.SerializeToString())
else:
self.assertEqual(0, len(message.SerializeToString()))
def testSerializeProto3(self):
# Verify that proto3 unknown fields behavior.
default_preserve = (api_implementation
.GetPythonProto3PreserveUnknownsDefault())
self.expectSerializeProto3(default_preserve)
api_implementation.SetPythonProto3PreserveUnknownsDefault(
not default_preserve)
self.expectSerializeProto3(not default_preserve)
api_implementation.SetPythonProto3PreserveUnknownsDefault(default_preserve)
self.assertEqual(self.all_fields_data, message.SerializeToString())
def testByteSize(self):
self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
@ -169,13 +149,15 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data)
# CheckUnknownField() is an additional Pure Python check which checks
# InternalCheckUnknownField() is an additional Pure Python check which checks
# a detail of unknown fields. It cannot be used by the C++
# implementation because some protect members are called.
# The test is added for historical reasons. It is not necessary as
# serialized string is checked.
def CheckUnknownField(self, name, expected_value):
# TODO(jieluo): Remove message._unknown_fields.
def InternalCheckUnknownField(self, name, expected_value):
if api_implementation.Type() == 'cpp':
return
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
@ -183,36 +165,80 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
for tag_bytes, value in self.empty_message._unknown_fields:
if tag_bytes == field_tag:
decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
decoder(value, 0, len(value), self.all_fields, result_dict)
decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
self.assertEqual(expected_value, result_dict[field_descriptor])
@SkipCheckUnknownFieldIfCppImplementation
def CheckUnknownField(self, name, unknown_fields, expected_value):
field_descriptor = self.descriptor.fields_by_name[name]
expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
field_descriptor.type]
for unknown_field in unknown_fields:
if unknown_field.field_number == field_descriptor.number:
self.assertEqual(expected_type, unknown_field.wire_type)
if expected_type == 3:
# Check group
self.assertEqual(expected_value[0],
unknown_field.data[0].field_number)
self.assertEqual(expected_value[1], unknown_field.data[0].wire_type)
self.assertEqual(expected_value[2], unknown_field.data[0].data)
continue
if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
self.assertIn(unknown_field.data, expected_value)
else:
self.assertEqual(expected_value, unknown_field.data)
def testCheckUnknownFieldValue(self):
unknown_fields = self.empty_message.UnknownFields()
# Test enum.
self.CheckUnknownField('optional_nested_enum',
unknown_fields,
self.all_fields.optional_nested_enum)
self.InternalCheckUnknownField('optional_nested_enum',
self.all_fields.optional_nested_enum)
# Test repeated enum.
self.CheckUnknownField('repeated_nested_enum',
unknown_fields,
self.all_fields.repeated_nested_enum)
self.InternalCheckUnknownField('repeated_nested_enum',
self.all_fields.repeated_nested_enum)
# Test varint.
self.CheckUnknownField('optional_int32',
unknown_fields,
self.all_fields.optional_int32)
self.InternalCheckUnknownField('optional_int32',
self.all_fields.optional_int32)
# Test fixed32.
self.CheckUnknownField('optional_fixed32',
unknown_fields,
self.all_fields.optional_fixed32)
self.InternalCheckUnknownField('optional_fixed32',
self.all_fields.optional_fixed32)
# Test fixed64.
self.CheckUnknownField('optional_fixed64',
unknown_fields,
self.all_fields.optional_fixed64)
self.InternalCheckUnknownField('optional_fixed64',
self.all_fields.optional_fixed64)
# Test lengthd elimited.
self.CheckUnknownField('optional_string',
self.all_fields.optional_string)
unknown_fields,
self.all_fields.optional_string.encode('utf-8'))
self.InternalCheckUnknownField('optional_string',
self.all_fields.optional_string)
# Test group.
self.CheckUnknownField('optionalgroup',
self.all_fields.optionalgroup)
unknown_fields,
(17, 0, 117))
self.InternalCheckUnknownField('optionalgroup',
self.all_fields.optionalgroup)
self.assertEqual(97, len(unknown_fields))
def testCopyFrom(self):
message = unittest_pb2.TestEmptyMessage()
@ -230,9 +256,18 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
message.optional_int64 = 3
message.optional_uint32 = 4
destination = unittest_pb2.TestEmptyMessage()
unknown_fields = destination.UnknownFields()
self.assertEqual(0, len(unknown_fields))
destination.ParseFromString(message.SerializeToString())
# ParseFromString clears the message thus unknown fields is invalid.
with self.assertRaises(ValueError) as context:
len(unknown_fields)
self.assertIn('UnknownFields does not exist.',
str(context.exception))
unknown_fields = destination.UnknownFields()
self.assertEqual(2, len(unknown_fields))
destination.MergeFrom(source)
self.assertEqual(4, len(unknown_fields))
# Check that the fields where correctly merged, even stored in the unknown
# fields set.
message.ParseFromString(destination.SerializeToString())
@ -241,9 +276,58 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
self.assertEqual(message.optional_int64, 3)
def testClear(self):
unknown_fields = self.empty_message.UnknownFields()
self.empty_message.Clear()
# All cleared, even unknown fields.
self.assertEqual(self.empty_message.SerializeToString(), b'')
with self.assertRaises(ValueError) as context:
len(unknown_fields)
self.assertIn('UnknownFields does not exist.',
str(context.exception))
def testSubUnknownFields(self):
message = unittest_pb2.TestAllTypes()
message.optionalgroup.a = 123
destination = unittest_pb2.TestEmptyMessage()
destination.ParseFromString(message.SerializeToString())
sub_unknown_fields = destination.UnknownFields()[0].data
self.assertEqual(1, len(sub_unknown_fields))
self.assertEqual(sub_unknown_fields[0].data, 123)
destination.Clear()
with self.assertRaises(ValueError) as context:
len(sub_unknown_fields)
self.assertIn('UnknownFields does not exist.',
str(context.exception))
with self.assertRaises(ValueError) as context:
# pylint: disable=pointless-statement
sub_unknown_fields[0]
self.assertIn('UnknownFields does not exist.',
str(context.exception))
message.Clear()
message.optional_uint32 = 456
nested_message = unittest_pb2.NestedTestAllTypes()
nested_message.payload.optional_nested_message.ParseFromString(
message.SerializeToString())
unknown_fields = (
nested_message.payload.optional_nested_message.UnknownFields())
self.assertEqual(unknown_fields[0].data, 456)
nested_message.ClearField('payload')
self.assertEqual(unknown_fields[0].data, 456)
unknown_fields = (
nested_message.payload.optional_nested_message.UnknownFields())
self.assertEqual(0, len(unknown_fields))
def testUnknownField(self):
message = unittest_pb2.TestAllTypes()
message.optional_int32 = 123
destination = unittest_pb2.TestEmptyMessage()
destination.ParseFromString(message.SerializeToString())
unknown_field = destination.UnknownFields()[0]
destination.Clear()
with self.assertRaises(ValueError) as context:
unknown_field.data # pylint: disable=pointless-statement
self.assertIn('The parent message might be cleared.',
str(context.exception))
def testUnknownExtensions(self):
message = unittest_pb2.TestEmptyMessageWithExtensions()
@ -280,15 +364,13 @@ class UnknownEnumValuesTest(BaseTestCase):
def CheckUnknownField(self, name, expected_value):
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
result_dict = {}
for tag_bytes, value in self.missing_message._unknown_fields:
if tag_bytes == field_tag:
decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[
tag_bytes][0]
decoder(value, 0, len(value), self.message, result_dict)
self.assertEqual(expected_value, result_dict[field_descriptor])
unknown_fields = self.missing_message.UnknownFields()
for field in unknown_fields:
if field.field_number == field_descriptor.number:
if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
self.assertIn(field.data, expected_value)
else:
self.assertEqual(expected_value, field.data)
def testUnknownParseMismatchEnumValue(self):
just_string = missing_enum_values_pb2.JustString()
@ -317,7 +399,6 @@ class UnknownEnumValuesTest(BaseTestCase):
def testUnknownPackedEnumValue(self):
self.assertEqual([], self.missing_message.packed_nested_enum)
@SkipCheckUnknownFieldIfCppImplementation
def testCheckUnknownFieldValueForEnum(self):
self.CheckUnknownField('optional_nested_enum',
self.message.optional_nested_enum)

View File

@ -482,7 +482,7 @@ class _Parser(object):
('Message type "{0}" has no field named "{1}".\n'
' Available Fields(except extensions): {2}').format(
message_descriptor.full_name, name,
message_descriptor.fields))
[f.json_name for f in message_descriptor.fields]))
if name in names:
raise ParseError('Message type "{0}" should not have multiple '
'"{1}" fields.'.format(

View File

@ -268,6 +268,10 @@ class Message(object):
def ClearExtension(self, extension_handle):
raise NotImplementedError
def UnknownFields(self):
"""Returns the UnknownFieldSet."""
raise NotImplementedError
def DiscardUnknownFields(self):
raise NotImplementedError

View File

@ -39,9 +39,18 @@ my_proto_instance = message_classes['some.proto.package.MessageName']()
__author__ = 'matthewtoia@google.com (Matt Toia)'
from google.protobuf.internal import api_implementation
from google.protobuf import descriptor_pool
from google.protobuf import message
from google.protobuf import reflection
if api_implementation.Type() == 'cpp':
from google.protobuf.pyext import cpp_message as message_impl
else:
from google.protobuf.internal import python_message as message_impl
# The type of all Message classes.
_GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl.GeneratedProtocolMessageType
class MessageFactory(object):
@ -70,11 +79,11 @@ class MessageFactory(object):
descriptor_name = descriptor.name
if str is bytes: # PY2
descriptor_name = descriptor.name.encode('ascii', 'ignore')
result_class = reflection.GeneratedProtocolMessageType(
result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE(
descriptor_name,
(message.Message,),
{'DESCRIPTOR': descriptor, '__module__': None})
# If module not set, it wrongly points to the reflection.py module.
# If module not set, it wrongly points to message_factory module.
self._classes[descriptor] = result_class
for field in descriptor.fields:
if field.message_type:

View File

@ -42,16 +42,15 @@
// Then use the methods of the returned class:
// py_proto_api->GetMessagePointer(...);
#ifndef PYTHON_GOOGLE_PROTOBUF_PROTO_API_H__
#define PYTHON_GOOGLE_PROTOBUF_PROTO_API_H__
#ifndef GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
#define GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
#include <Python.h>
#include <google/protobuf/message.h>
namespace google {
namespace protobuf {
class Message;
namespace python {
// Note on the implementation:
@ -89,4 +88,4 @@ inline const char* PyProtoAPICapsuleName() {
} // namespace protobuf
} // namespace google
#endif // PYTHON_GOOGLE_PROTOBUF_PROTO_API_H__
#endif // GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__

View File

@ -32,8 +32,8 @@
#include <Python.h>
#include <frameobject.h>
#include <google/protobuf/stubs/hash.h>
#include <string>
#include <unordered_map>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/descriptor.pb.h>
@ -44,6 +44,7 @@
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/stubs/hash.h>
#if PY_MAJOR_VERSION >= 3
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize
@ -54,10 +55,12 @@
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#endif
#define PyString_AsStringAndSize(ob, charpp, sizep) \
(PyUnicode_Check(ob)? \
((*(charpp) = const_cast<char*>(PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL? -1: 0): \
PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#define PyString_AsStringAndSize(ob, charpp, sizep) \
(PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \
PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
? -1 \
: 0) \
: PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
namespace google {
@ -70,7 +73,7 @@ namespace python {
// released.
// This is enough to support the "is" operator on live objects.
// All descriptors are stored here.
hash_map<const void*, PyObject*> interned_descriptors;
std::unordered_map<const void*, PyObject*>* interned_descriptors;
PyObject* PyString_FromCppString(const string& str) {
return PyString_FromStringAndSize(str.c_str(), str.size());
@ -119,8 +122,10 @@ bool _CalledFromGeneratedFile(int stacklevel) {
PyErr_Clear();
return false;
}
if ((filename_size < 3) || (strcmp(&filename[filename_size - 3], ".py") != 0)) {
// Cython's stack does not have .py file name and is not at global module scope.
if ((filename_size < 3) ||
(strcmp(&filename[filename_size - 3], ".py") != 0)) {
// Cython's stack does not have .py file name and is not at global module
// scope.
return true;
}
if (filename_size < 7) {
@ -131,7 +136,7 @@ bool _CalledFromGeneratedFile(int stacklevel) {
// Filename is not ending with _pb2.
return false;
}
if (frame->f_globals != frame->f_locals) {
// Not at global module scope
return false;
@ -197,7 +202,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
// First search in the cache.
PyDescriptorPool* caching_pool = GetDescriptorPool_FromPool(
GetFileDescriptor(descriptor)->pool());
hash_map<const void*, PyObject*>* descriptor_options =
std::unordered_map<const void*, PyObject*>* descriptor_options =
caching_pool->descriptor_options;
if (descriptor_options->find(descriptor) != descriptor_options->end()) {
PyObject *value = (*descriptor_options)[descriptor];
@ -232,7 +237,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
if (value == NULL) {
return NULL;
}
if (!PyObject_TypeCheck(value.get(), &CMessage_Type)) {
if (!PyObject_TypeCheck(value.get(), CMessage_Type)) {
PyErr_Format(PyExc_TypeError, "Invalid class for %s: %s",
message_type->full_name().c_str(),
Py_TYPE(value.get())->tp_name);
@ -275,7 +280,7 @@ static PyObject* CopyToPythonProto(const DescriptorClass *descriptor,
const Descriptor* self_descriptor =
DescriptorProtoClass::default_instance().GetDescriptor();
CMessage* message = reinterpret_cast<CMessage*>(target);
if (!PyObject_TypeCheck(target, &CMessage_Type) ||
if (!PyObject_TypeCheck(target, CMessage_Type) ||
message->message->GetDescriptor() != self_descriptor) {
PyErr_Format(PyExc_TypeError, "Not a %s message",
self_descriptor->full_name().c_str());
@ -332,9 +337,9 @@ PyObject* NewInternedDescriptor(PyTypeObject* type,
}
// See if the object is in the map of interned descriptors
hash_map<const void*, PyObject*>::iterator it =
interned_descriptors.find(descriptor);
if (it != interned_descriptors.end()) {
std::unordered_map<const void*, PyObject*>::iterator it =
interned_descriptors->find(descriptor);
if (it != interned_descriptors->end()) {
GOOGLE_DCHECK(Py_TYPE(it->second) == type);
Py_INCREF(it->second);
return it->second;
@ -348,7 +353,7 @@ PyObject* NewInternedDescriptor(PyTypeObject* type,
py_descriptor->descriptor = descriptor;
// and cache it.
interned_descriptors.insert(
interned_descriptors->insert(
std::make_pair(descriptor, reinterpret_cast<PyObject*>(py_descriptor)));
// Ensures that the DescriptorPool stays alive.
@ -370,7 +375,7 @@ PyObject* NewInternedDescriptor(PyTypeObject* type,
static void Dealloc(PyBaseDescriptor* self) {
// Remove from interned dictionary
interned_descriptors.erase(self->descriptor);
interned_descriptors->erase(self->descriptor);
Py_CLEAR(self->pool);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@ -758,6 +763,11 @@ static PyObject* HasDefaultValue(PyBaseDescriptor *self, void *closure) {
static PyObject* GetDefaultValue(PyBaseDescriptor *self, void *closure) {
PyObject *result;
if (_GetDescriptor(self)->is_repeated()) {
return PyList_New(0);
}
switch (_GetDescriptor(self)->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32: {
int32 value = _GetDescriptor(self)->default_value_int32();
@ -805,6 +815,10 @@ static PyObject* GetDefaultValue(PyBaseDescriptor *self, void *closure) {
result = PyInt_FromLong(value->number());
break;
}
case FieldDescriptor::CPPTYPE_MESSAGE: {
Py_RETURN_NONE;
break;
}
default:
PyErr_Format(PyExc_NotImplementedError, "default value for %s",
_GetDescriptor(self)->full_name().c_str());
@ -1919,6 +1933,9 @@ bool InitDescriptor() {
if (!InitDescriptorMappingTypes())
return false;
// Initialize globals defined in this file.
interned_descriptors = new std::unordered_map<const void*, PyObject*>;
return true;
}

View File

@ -100,6 +100,6 @@ bool InitDescriptor();
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__

Some files were not shown because too many files have changed in this diff Show More