Blob Blame History Raw
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#ifndef tls_filter_h_
#define tls_filter_h_

#include <memory>
#include <vector>

#include "test_io.h"
#include "tls_parser.h"

namespace nss_test {

// Base class for packet filters that work on whole (D)TLS records.
// Concrete filters implement FilterRecord(); Filter() and the private
// ApplyFilter() helper are defined out of line.
class TlsRecordFilter : public PacketFilter {
 public:
  // Starts with zero altered packets.
  TlsRecordFilter() : count_(0) {}

  // Packet-level entry point.
  // NOTE(review): the meaning of the return value and when |output| is
  // populated are set by the out-of-line definition - confirm there.
  virtual bool Filter(const DataBuffer& input, DataBuffer* output);

  // Returns the number of packets this filter has altered.
  size_t filtered_packets() const { return count_; }

 protected:
  // Per-record hook that every concrete filter must provide.  It is handed
  // the record's content type and version together with the record |data|,
  // plus a |changed| buffer to write into.
  // NOTE(review): presumably returns true when |changed| should replace the
  // record - confirm against the implementation.
  virtual bool FilterRecord(uint8_t content_type, uint16_t version,
                            const DataBuffer& data, DataBuffer* changed) = 0;

 private:
  // Internal helper, defined out of line.
  size_t ApplyFilter(uint8_t content_type, uint16_t version,
                     const DataBuffer& record, DataBuffer* output,
                     size_t offset, bool* changed);

  // Backing counter for filtered_packets().
  size_t count_;
};

// Base class for filters that work on individual handshake messages instead
// of whole records.  It relies on handshake messages being written as a
// block of complete records; messages that span records are not supported.
class TlsHandshakeFilter : public TlsRecordFilter {
 public:
  TlsHandshakeFilter() {}

  // Reads the length from the record header into |length|.  The DTLS
  // fragment information is read and checked here as well.
  static bool ReadLength(TlsParser* parser, uint16_t version,
                         uint32_t* length);

 protected:
  // Record-level hook inherited from TlsRecordFilter, defined out of line.
  virtual bool FilterRecord(uint8_t content_type, uint16_t version,
                            const DataBuffer& input, DataBuffer* output);

  // Per-message hook that concrete handshake filters must provide.  It is
  // given the version, the handshake message type and the message |input|,
  // plus an |output| buffer to write into.
  // NOTE(review): presumably returns true when |output| should replace the
  // message - confirm against the implementation.
  virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
                               const DataBuffer& input,
                               DataBuffer* output) = 0;

 private:
  // Internal helper, defined out of line.
  size_t ApplyFilter(uint16_t version, uint8_t handshake_type,
                     const DataBuffer& record, DataBuffer* output,
                     size_t length_offset, size_t value_offset,
                     bool* changed);
};

// Make a copy of the first instance of a handshake message.
class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
 public:
  // |handshake_type| is the handshake message type this inspector is set up
  // for.  The constructor is explicit so that a bare integer is never
  // silently converted into a filter object.
  explicit TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
      : handshake_type_(handshake_type), buffer_() {}

  // Handshake-level hook inherited from TlsHandshakeFilter, defined out of
  // line.
  virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
                               const DataBuffer& input, DataBuffer* output);

  // Read-only access to the stored buffer.
  // NOTE(review): presumably still default-constructed (empty) until a
  // matching message has been seen - confirm against the implementation.
  const DataBuffer& buffer() const { return buffer_; }

 private:
  // Handshake message type supplied at construction.
  uint8_t handshake_type_;
  // Storage returned by buffer().
  DataBuffer buffer_;
};

// Filter that remembers an alert.  Once an alert has been stored, a later
// alert only replaces it when the stored one is a warning and the new one is
// fatal.
class TlsAlertRecorder : public TlsRecordFilter {
 public:
  // Both fields start out as 255.
  // NOTE(review): 255 presumably stands for "nothing recorded yet" - confirm
  // against the implementation.
  TlsAlertRecorder()
      : level_(255),
        description_(255) {}

  // Record-level hook inherited from TlsRecordFilter, defined out of line.
  virtual bool FilterRecord(uint8_t content_type, uint16_t version,
                            const DataBuffer& input, DataBuffer* output);

  // Level of the stored alert (255 initially).
  uint8_t level() const { return level_; }
  // Description of the stored alert (255 initially).
  uint8_t description() const { return description_; }

 private:
  uint8_t level_;
  uint8_t description_;
};

// Runs multiple packet filters in series.
class ChainedPacketFilter : public PacketFilter {
 public:
  ChainedPacketFilter() {}
  ChainedPacketFilter(const std::vector<PacketFilter*> filters)
      : filters_(filters.begin(), filters.end()) {}
  virtual ~ChainedPacketFilter();

  virtual bool Filter(const DataBuffer& input, DataBuffer* output);

  // Takes ownership of the filter.
  void Add(PacketFilter* filter) {
    filters_.push_back(filter);
  }

 private:
  std::vector<PacketFilter*> filters_;
};

}  // namespace nss_test

#endif