#pragma once

#include <ios>
#include <streambuf>
#include <iostream>
#include <memory>
#include <new>
#include <string>
#include <vector>
#include <map>

#include <vcclr.h>
#using <mscorlib.dll>
#include <Windows.h>

namespace HashFilter {
	/// Bit flags selecting which digests a hashfilter computes.
	/// Flags can be OR-ed together; the resulting int must be cast back to
	/// Hash, e.g. (Hash)(SHA256 | MD5).
	enum Hash {
		SHA512    = 1 << 0,
		SHA384    = 1 << 1,
		SHA256    = 1 << 2,
		SHA1      = 1 << 3,
		RIPEMD160 = 1 << 4,
		MD5       = 1 << 5
	};

	const int buffersize = 64 * 1024 * 1024;

	class hashfilterbuf : public std::streambuf {
	public:
		hashfilterbuf(Hash hashtype)
			: buffer((char*)VirtualAlloc(NULL, buffersize, MEM_COMMIT, PAGE_READWRITE), [](char* p) {::VirtualFree(p, 0, MEM_RELEASE); })
		{
			type = hashtype;
			if ((type & Hash::SHA512) != 0)
				hashlist[Hash::SHA512] = gcnew System::Security::Cryptography::SHA512Managed();
			if ((type & Hash::SHA384) != 0)
				hashlist[Hash::SHA384] = gcnew System::Security::Cryptography::SHA384Managed();
			if ((type & Hash::SHA256) != 0)
				hashlist[Hash::SHA256] = gcnew System::Security::Cryptography::SHA256Managed();
			if ((type & Hash::SHA1) != 0)
				hashlist[Hash::SHA1] = gcnew System::Security::Cryptography::SHA1CryptoServiceProvider();
			if ((type & Hash::RIPEMD160) != 0)
				hashlist[Hash::RIPEMD160] = gcnew System::Security::Cryptography::RIPEMD160Managed();
			if ((type & Hash::MD5) != 0)
				hashlist[Hash::MD5] = gcnew System::Security::Cryptography::MD5CryptoServiceProvider();

			setp(buffer.get(), buffer.get() + buffersize);
			buffer_basepos = 0;
			hashed_pos = 0;
			buffer_max = 0;
			bufferslot[buffer_basepos] = buffer;
			slotdone[buffer_basepos] = false;
		}
		~hashfilterbuf(void)
		{
			buffer.reset();
		}
		std::string GetHash(Hash hashtype){
			const char table[] = "0123456789abcdef";
			auto hash = hashlist[hashtype];
			if (static_cast<System::Security::Cryptography::HashAlgorithm^>(hash) == nullptr) return "";
			std::string hashstr;
			auto hashbytearray = hash->Hash;
			pin_ptr<System::Byte> hashp = &hashbytearray[0];
			for (int i = 0; i < hashbytearray->Length; i++){
				hashstr += table[hashp[i] >> 4];
				hashstr += table[hashp[i] & 0x0f];
			}
			return hashstr;
		}
	protected:
		// overflow
		int_type overflow(int_type c = std::char_traits<char>::eof())
		{
			pos_type slotmask = ~(pos_type)(buffersize - 1);
			// obNobt@ɕۑ
			bufferslot[buffer_basepos] = buffer;
			slotdone[buffer_basepos] = true;
			// ̃obt@p
			buffer_basepos += pptr() - pbase();
			buffer_max = (buffer_basepos > buffer_max) ? buffer_basepos : buffer_basepos;
			if (bufferslot.find(buffer_basepos & slotmask) == bufferslot.end()){
				bufferslot[buffer_basepos & slotmask] = std::shared_ptr<char>((char*)VirtualAlloc(NULL, buffersize, MEM_COMMIT, PAGE_READWRITE), [](char* p) {::VirtualFree(p, 0, MEM_RELEASE); });
			}
			// obNobt@珑߂
			buffer = bufferslot[buffer_basepos & slotmask];
			setp(buffer.get(), buffer.get() + buffersize);

			// ݂̕
			*pptr() = std::char_traits<char>::to_char_type(c);
			pbump(1);

			TransformBlock();
			return c;
		}

		// sync
		int sync(void)
		{
			TransformFinalBlock();
			return 0;
		}

		// seekoff
		pos_type seekoff(off_type off, std::ios_base::seekdir way,
			std::ios_base::openmode which = std::ios_base::in | std::ios_base::out)
		{
			pos_type pos;
			pos_type targetpos;

			// ݂A擪炩݈ʒȗ΂̂݋
			if ((which & std::ios_base::in) != 0) 
				return std::ios::pos_type(std::ios::off_type(-1));

			switch (way){
			case std::ios_base::beg:
				targetpos = off;
				break;
			case std::ios_base::end:
				return std::ios::pos_type(std::ios::off_type(-1));
			case std::ios_base::cur:
				targetpos = buffer_basepos + (pos_type)(pptr() - pbase()) + off;
				break;
			}
			if (((targetpos - buffer_basepos) >=0) && ((targetpos - buffer_basepos) < buffersize)){
				pbump((int)off);
			}
			else {
				ChangeBufferSlot(targetpos);
			}
			TransformBlock();
			return targetpos;
		}

		// seekpos
		pos_type seekpos(pos_type pos, std::ios_base::openmode which = std::ios_base::in | std::ios_base::out)
		{
			return seekoff(off_type(pos), std::ios_base::beg, which);
		}

	protected:
		void ChangeBufferSlot(pos_type newpos){
			pos_type slotmask = ~(pos_type)(buffersize - 1);
			// ړ̃obNobt@Ȃ΍
			if (bufferslot.find(newpos & slotmask) == bufferslot.end()){
				bufferslot[newpos & slotmask] = std::shared_ptr<char>((char*)VirtualAlloc(NULL, buffersize, MEM_COMMIT, PAGE_READWRITE), [](char* p) {::VirtualFree(p, 0, MEM_RELEASE); });
			}
			// ݂̃obNobt@ɕۑ
			bufferslot[buffer_basepos] = buffer;
			slotdone[buffer_basepos] = (pptr() == epptr());
			pos_type nowpos = pptr() - pbase() + buffer_basepos;
			buffer_max = (nowpos > buffer_max) ? nowpos : buffer_max;
			// ړ̃obNobt@ǂݏo
			buffer_basepos = newpos & slotmask;
			buffer = bufferslot[buffer_basepos];
			setp(buffer.get(), buffer.get() + newpos - buffer_basepos, buffer.get() + buffersize);
			buffer_max = (newpos > buffer_max) ? newpos : buffer_max;
		}
		void TransformBlock(void){
			while (hashed_pos < buffer_basepos){
				if (!slotdone[hashed_pos]) break;

				array<System::Byte>^ tmpbuf = gcnew array<System::Byte>(buffersize);
				System::Runtime::InteropServices::Marshal::Copy((System::IntPtr)bufferslot[hashed_pos].get(), tmpbuf, 0, tmpbuf->Length);
				for (auto it = hashlist.begin(); it != hashlist.end(); ++it){
					it->second->TransformBlock(tmpbuf, 0, tmpbuf->Length, nullptr, 0);
				}
				// gÎŏ
				bufferslot[hashed_pos].reset();
				bufferslot.erase(hashed_pos);
				slotdone.erase(hashed_pos);
				hashed_pos += tmpbuf->Length;
			}
		}
		void TransformFinalBlock(void){
			pos_type slotmask = ~(pos_type)(buffersize - 1);
			// ݂̃obNobt@ɕۑ
			bufferslot[buffer_basepos] = buffer;
			slotdone[buffer_basepos] = false;
			pos_type nowpos = pptr() - pbase() + buffer_basepos;
			buffer_max = (nowpos > buffer_max) ? nowpos : buffer_max;

			while (hashed_pos < buffer_max){
				auto len = buffer_max - hashed_pos;
				len = (len > buffersize) ? buffersize : len;

				// obNobt@Ȃ΍
				if (bufferslot.find(hashed_pos) == bufferslot.end()){
					bufferslot[hashed_pos] = std::shared_ptr<char>((char*)VirtualAlloc(NULL, buffersize, MEM_COMMIT, PAGE_READWRITE), [](char* p) {::VirtualFree(p, 0, MEM_RELEASE); });
				}

				array<System::Byte>^ tmpbuf = gcnew array<System::Byte>((int)len);
				System::Runtime::InteropServices::Marshal::Copy((System::IntPtr)bufferslot[hashed_pos].get(), tmpbuf, 0, tmpbuf->Length);
				for (auto it = hashlist.begin(); it != hashlist.end(); ++it){
					it->second->TransformBlock(tmpbuf, 0, tmpbuf->Length, nullptr, 0);
				}
				// gÎŏ
				bufferslot[hashed_pos].reset();
				bufferslot.erase(hashed_pos);
				slotdone.erase(hashed_pos);
				hashed_pos += tmpbuf->Length;
			}
			for (auto it = hashlist.begin(); it != hashlist.end(); ++it){
				array<System::Byte>^ tmpbuf = gcnew array<System::Byte>(0);
				it->second->TransformFinalBlock(tmpbuf, 0, 0);
			}
		}
	private:
		Hash type;
		std::map<Hash,gcroot<System::Security::Cryptography::HashAlgorithm^>> hashlist;

		std::shared_ptr<char> buffer;
		std::map<pos_type, std::shared_ptr<char>> bufferslot;
		std::map<pos_type, bool> slotdone;
		pos_type buffer_basepos;
		pos_type hashed_pos;
		pos_type buffer_max;
	};

	/// std::ostream that computes message digests of everything written to it.
	///
	/// Select the algorithms with the Hash bit mask passed to the constructor,
	/// write the data, then flush() the stream (which calls the buffer's
	/// sync() and finalizes the digests) before calling the accessors below.
	/// Each accessor returns the lower-case hexadecimal digest, or "" when
	/// that algorithm was not selected.
	class hashfilter : public std::ostream {
	private:
		hashfilterbuf streambuf;  // hashing buffer this stream writes into
	public:
		hashfilter(Hash hashtype = (Hash)0)
			: std::ostream(nullptr)
			, streambuf(hashtype)
		{
			// Base classes are constructed before members, so the ostream is
			// created without a buffer and attached here once `streambuf`
			// exists. rdbuf() also clears the badbit caused by the null buffer.
			rdbuf(&streambuf);
		}
		/// SHA-512 digest as lower-case hex, or "" if not selected.
		std::string sha512(void)
		{
			return streambuf.GetHash(Hash::SHA512);
		}
		/// SHA-384 digest as lower-case hex, or "" if not selected.
		std::string sha384(void)
		{
			return streambuf.GetHash(Hash::SHA384);
		}
		/// SHA-256 digest as lower-case hex, or "" if not selected.
		std::string sha256(void)
		{
			return streambuf.GetHash(Hash::SHA256);
		}
		/// SHA-1 digest as lower-case hex, or "" if not selected.
		std::string sha1(void)
		{
			return streambuf.GetHash(Hash::SHA1);
		}
		/// RIPEMD-160 digest as lower-case hex, or "" if not selected.
		std::string ripemd160(void)
		{
			return streambuf.GetHash(Hash::RIPEMD160);
		}
		/// MD5 digest as lower-case hex, or "" if not selected.
		std::string md5(void)
		{
			return streambuf.GetHash(Hash::MD5);
		}
	};

}