#pragma once

#include <ios>
#include <streambuf>
#include <iostream>
#include <vector>
#using <System.Core.dll>
#include <vcclr.h>

// ref http://www.jah.ne.jp/~naoyuki/Writings/ExtIos.html

template <class Ch, class Tr = std::char_traits<Ch> >
class basic_AESfilterbuf : public std::basic_streambuf<Ch, Tr> {
public:
	// o̓RXgN^
	basic_AESfilterbuf(std::ostream& os)
	{
		out = &os;
		in = NULL;

		sbuffer.resize(4096);
		buffer = &sbuffer[0];
		size = sbuffer.size();
		setp(buffer, buffer + size);

		setupcrypt();
		out->write(salt.data(), salt.size());

		cStream = gcnew System::Security::Cryptography::CryptoStream(mStream,
			aes->CreateEncryptor(),
			System::Security::Cryptography::CryptoStreamMode::Write);
	}
	// ̓RXgN^
	basic_AESfilterbuf(std::istream& is)
	{
		out = NULL;
		in = &is;

		sbuffer.resize(4096);
		buffer = &sbuffer[0];
		size = sbuffer.size();
		setg(buffer, buffer, buffer);

		salt.resize(8);
		in->read(&salt[0], salt.size());
		finished = (in->gcount() != salt.size());
		salt.resize(in->gcount());
		setupcrypt();

		cStream = gcnew System::Security::Cryptography::CryptoStream(mStream,
			aes->CreateDecryptor(),
			System::Security::Cryptography::CryptoStreamMode::Read);
		std::vector<char> buf;
		buf.resize(size);
		while (true){
			in->read(&buf[0], size);
			buf.resize((unsigned int)in->gcount());
			if (in->gcount() == 0) break;

			array<System::Byte>^ tmpbuf = gcnew array<System::Byte>((int)buf.size());
			System::Runtime::InteropServices::Marshal::Copy((System::IntPtr)&buf[0], tmpbuf, 0, (int)buf.size());
			mStream->Write(tmpbuf, 0, tmpbuf->Length);

			if (in->gcount() != size)
				break;
		}
		mStream->Position = 0;
	}
	// fXgN^
	~basic_AESfilterbuf(void)
	{
		finished = true;
		sync();
	}

protected:
	// overflow
	int_type overflow(int_type c = Tr::eof())
	{
		encrypt();
		if (c != Tr::eof()){
			*pptr() = Tr::to_char_type(c);
			pbump(1);
			return Tr::not_eof(c);
		}
		else {
			return Tr::eof();
		}
	}

	// underflow
	int_type underflow(void)
	{
		if (egptr() <= gptr()){
			if (finished){
				return Tr::eof();
			}
			decrypt();
			if (egptr() <= gptr()){
				return Tr::eof();
			}
		}
		return Tr::to_int_type(*gptr());
	}

	// sync
	int sync(void)
	{
		if (in != NULL){
		}
		if (out != NULL){
			encrypt();
		}

		return 0;
	}

protected:
	// setbuf
	std::basic_streambuf<Ch, Tr>* setbuf(Ch* b, int s)
	{
		if (out != NULL){
			sbuffer.resize(0);
			setp(b, b, b + s);
		}
		if (in != NULL){
			dbuffer.resize(0);
			setg(b, b, b + s);
		}
		buffer = b;
		size = s;
		return this;
	}

	void setupcrypt(void)
	{
		System::Security::Cryptography::Rfc2898DeriveBytes^ deriveBytes;
		if (salt.size() != 8){
			array<System::Byte>^ saltarr = gcnew array<System::Byte>(8);
			auto rngCsp = gcnew System::Security::Cryptography::RNGCryptoServiceProvider();
			rngCsp->GetBytes(saltarr);
			salt.resize(8);
			System::Runtime::InteropServices::Marshal::Copy(saltarr, 0, (System::IntPtr)&salt[0], 8);
			deriveBytes = gcnew System::Security::Cryptography::Rfc2898DeriveBytes("password", saltarr);
		}
		else{
			array<System::Byte>^ saltarr = gcnew array<System::Byte>(8);
			salt.resize(8);
			System::Runtime::InteropServices::Marshal::Copy((System::IntPtr)&salt[0], saltarr, 0, 8);
			deriveBytes = gcnew System::Security::Cryptography::Rfc2898DeriveBytes("password", saltarr);
		}

		aes = gcnew System::Security::Cryptography::AesCryptoServiceProvider();
		aes->Mode = System::Security::Cryptography::CipherMode::CBC;
		aes->Padding = System::Security::Cryptography::PaddingMode::PKCS7;
		aes->Key = deriveBytes->GetBytes(aes->KeySize / 8);
		aes->IV = deriveBytes->GetBytes(aes->BlockSize / 8);

		mStream = gcnew System::IO::MemoryStream();
	}
	// encrypt
	void encrypt(void)
	{
		int_type len = (int_type)(pptr() - pbase());
		array<System::Byte>^ tmpbuf = gcnew array<System::Byte>(len);
		System::Runtime::InteropServices::Marshal::Copy((System::IntPtr)pbase(), tmpbuf, 0, tmpbuf->Length);
		cStream->Write(tmpbuf, 0, tmpbuf->Length);
		if (finished) {
			cStream->FlushFinalBlock();

			auto Data = mStream->ToArray();
			pin_ptr<System::Byte> dp = &Data[0];
			out->write((char *)dp, Data->Length);
			dp = nullptr;
		}

		setp(buffer, buffer + size);
	}

	// decrypt
	void decrypt(void)
	{
		array<System::Byte>^ decbuf = gcnew array<System::Byte>((int)size);
		auto len = cStream->Read(decbuf, 0, decbuf->Length);
		finished = (len != size);
		System::Runtime::InteropServices::Marshal::Copy(decbuf, 0, (System::IntPtr)eback(), len);

		setg(eback(), eback(), eback() + len);
	}

private:
	std::ostream* out;
	std::istream* in;

	char    *buffer;
	size_t	size;
	bool	finished;

	std::vector<char> sbuffer;
	std::vector<char> salt;
	gcroot<System::Security::Cryptography::AesCryptoServiceProvider^> aes;
	gcroot<System::IO::MemoryStream^> mStream;
	gcroot<System::Security::Cryptography::CryptoStream^> cStream;
};

template <class Ch, class Tr = std::char_traits<Ch> >
class basic_oAESfilter : public std::basic_ostream<Ch, Tr> {
public:
	basic_oAESfilter(std::ostream& os)
		: std::basic_ostream<Ch, Tr>(new basic_AESfilterbuf<Ch, Tr>(os))
	{
		}

	~basic_oAESfilter(void)
	{
		delete rdbuf();
	}
};

template <class Ch, class Tr = std::char_traits<Ch> >
class basic_iAESfilter : public std::basic_istream<Ch, Tr> {
public:
	basic_iAESfilter(std::istream& is)
		: std::basic_istream<Ch, Tr>(new basic_AESfilterbuf<Ch, Tr>(is))
	{
		}

	~basic_iAESfilter(void)
	{
		delete rdbuf();
	}
};

// Narrow-character convenience names (the filters only work with char streams).
typedef basic_oAESfilter<char, std::char_traits<char> > oAESfilter;	// encrypting ostream
typedef basic_iAESfilter<char, std::char_traits<char> > iAESfilter;	// decrypting istream
