/*****************************************************************************

        mod_mid.cpp
        Author: Laurent de Soras, 2022

Rationale for song loading:

- MIDI channels are mapped to instruments with the same index (1-16)
- Each MIDI channel occupies a fixed set of tracks. We have to scan all MIDI
	tracks to find the maximum activity (simultaneous playing notes) for each
	channel.
- Default score resolution: 1 quarter = 8 rows at the current GT speed (ticks
	per row). This means the tempo is automatically doubled for the default
	speed of 6.
- Pattern and MIDI tracks:
	- In formats 0 and 2, each MIDI track represents one pattern (or more if it
		is too long)
	- In format 1, all MIDI tracks are merged together in a single virtual
		pattern of infinite length, then cut at MIDI track boundaries and at
		the maximum GT pattern length (256 lines). Ideally these cuts should be
		rounded to the closest quarter or bar; this is not implemented yet.

--- Legal stuff ---

This program is free software. It comes without any warranty, to
the extent permitted by applicable law. You can redistribute it
and/or modify it under the terms of the Do What The Fuck You Want
To Public License, Version 2, as published by Sam Hocevar. See
http://www.wtfpl.net/ for more details.

*Tab=3***********************************************************************/



/*\\\ INCLUDE FILES \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\*/

#include "fstb/def.h"
#include "fstb/fnc.h"
#include "inst.h"
#include "log.h"
#include "MidiFile.h"
#include "MidiVeloMapper.h"
#include "mod_mid.h"
#include "mods_ct.h"
#include "Player.h"
#include "song.h"

#include <algorithm>
#include <array>
#include <limits>
#include <memory>
#include <numeric>
#include <set>
#include <vector>

#include <cassert>
#include <cstdint>



/*\\\ CONSTANTS \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\*/



// Dynamic range for the velocity
static constexpr int MODMID_velo_dr = 40;



/*\\\ TYPES & STRUCTURES \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\*/



// Number of GT tracks used by each MIDI channel, indexed by channel
using TrackCount = std::array <int, midi::_nbr_chn>;

// Set of note end timestamps
using NoteEndSet = std::set <int32_t>;

// A note of a MIDI channel, together with the channel subtrack it has been
// assigned to
class MidiNote
{
public:
	// Note data, as provided by the MIDI file parser
	MidiFileTrack::Note
	               _note;
	// Track index within the channel. < 0: invalid (not set yet)
	int            _trk_idx { -1 };
};

// All the notes of one MIDI channel within a pattern, plus the results of
// the subtrack allocation
class MidiChannel
{
public:
	// Notes of the channel
	std::vector <MidiNote>
	               _note_arr;
	// Number of internal tracks. < 0: invalid (not set yet)
	int            _nbr_tracks { -1 };
	// End timestamp of the last note. Defines the minimum time boundaries.
	// < 0: invalid (not set yet)
	int            _ts_end     { -1 };
};

// Intermediate pattern built from one or more MIDI tracks: per-channel
// notes, tempo and time signature changes, and candidate split points
class MidiPattern
{
public:
	// One entry per MIDI channel
	std::array <MidiChannel, midi::_nbr_chn>
	               _chn_arr;
	// Tempo changes
	MidiFileTrack::TempoMap
	               _tempo_map;
	// Time signature changes
	MidiFileTrack::TimeSigMap
	               _time_sig_map;
	// Timestamps where the pattern could be split (MIDI file format 1)
	std::vector <int32_t>
	               _split_list;
};

using PatternList = std::vector <MidiPattern>;



// Parameters used to convert MIDI timing into GT lines and ticks
class TimeRef
{
public:
	// Timing information of the MIDI file
	MidiFile::TimeInfo
	                  _midi;

	// Ticks per line
	int               _speed             { 6 };
	int               _lines_per_quarter { 8 };

	// Tempo multiplier: original tempo -> GT tempo
	double            _tempo_mul         { 1 };
};

// Position within a GT pattern: line index and tick within that line
class GtPatTime
{
public:
	int            _line { 0 };
	int            _tick { 0 };
};



class PatTmp
{
public:
	typedef std::vector <MODS_GT2_SPL_NOTE> Track;
	typedef std::vector <Track> TrackArray;

	explicit       PatTmp (int nbr_tracks, int nbr_lines);
	int            get_nbr_lines () const noexcept;

	TrackArray     _track_arr;
};

PatTmp::PatTmp (int nbr_tracks, int nbr_lines)
:	_track_arr (nbr_tracks, Track (nbr_lines))
{
	assert (nbr_tracks > 0);
	assert (nbr_lines > 0);
}

int	PatTmp::get_nbr_lines () const noexcept
{
	return int (_track_arr [0].size ());
}



/*\\\ PROTOTYPES \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\*/



// Reads a whole file into content. Returns 0 on success, -1 on error.
static int	load_file_in_memory (std::vector <uint8_t> &content, const char filename_0 []);
static int	load_file_in_memory (std::vector <uint8_t> &content, FILE *f_ptr);

// MIDI data analysis: groups the MIDI tracks into patterns, allocates the
// per-channel subtracks and extracts the initial tempo
static PatternList	reorganize_patterns (const MidiFile &mf);
static TrackCount	count_tracks (PatternList &pat_list);
static int	count_tracks (MidiChannel &chn);
static double	find_tempo (const MidiFile &mf);

// Conversion of MIDI patterns into GT patterns
static PatTmp	build_one_pattern (const MidiPattern &mp, const TimeRef &tref, const TrackCount &tc);
static int32_t	find_end_timestamp (const MidiPattern &mp);
static GtPatTime	conv_time (int32_t timestamp, const MidiPattern &mp, const TimeRef &tref);
static int	add_pattern_part (const PatTmp &pat_tmp, int pat_idx, int line_beg, int line_end);



/*\\\ PUBLIC FUNCTIONS \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\*/



// Saves the module in MIDI format. Same as MODMID_save_song().
// Returns the MODMID_save_song() result.
int	MODMID_save_module (FILE *file_ptr)
{
	const int      result = MODMID_save_song (file_ptr);

	return result;
}



int	MODMID_save_song (FILE *file_ptr)
{
	fstb::unused (file_ptr);

	assert (file_ptr != nullptr);

	int            ret_val = 0;

	/*** To do ***/
	assert (false);

	return ret_val;
}



// Loads a MIDI file as a module. Same as MODMID_load_song().
// Returns the MODMID_load_song() result.
int	MODMID_load_module (FILE *file_ptr, BYTE temp_info [MODS_TEMP_INFO_LEN])
{
	const int      result = MODMID_load_song (file_ptr, temp_info);

	return result;
}



int	MODMID_load_song (FILE *file_ptr, BYTE temp_info [MODS_TEMP_INFO_LEN])
{
	fstb::unused (temp_info);

	assert (file_ptr != nullptr);

	// Loads and parse the file
	std::vector <uint8_t>   content;
	int            ret_val = load_file_in_memory (content, file_ptr);
	if (ret_val != 0)
	{
		LOG_printf ("MODMID_load_song: cannot open and load file.\n");
	}
	std::unique_ptr <MidiFile>  midi_uptr;
	try
	{
		midi_uptr = std::make_unique <MidiFile> (
			content.data (), content.data () + content.size (),
			MidiFileTrack::NoteRepeat::_accumulate
		);
	}
	catch (std::exception &e)
	{
		LOG_printf (
			"MODMID_load_song: error during file parsing: %s\n", e.what ()
		);
		return -1;
	}
	catch (...)
	{
		LOG_printf ("MODMID_load_song: error during file parsing.\n");
		return -1;
	}
	const auto &   mf = *midi_uptr;

	auto           pat_list = reorganize_patterns (mf);

	// Finds the number of tracks for all channels
	const auto     tc = count_tracks (pat_list);

	LOG_printf ("MODMID_load_song: tracks per channel:");
	for (auto n : tc)
	{
		LOG_printf (" %d", int (n));
	}
	LOG_printf ("\n");

	// Checks the total number of tracks
	const auto     total_nbr_tracks = std::accumulate (tc.begin (), tc.end (), 0);
	if (total_nbr_tracks == 0)
	{
		LOG_printf ("Aborting: file doesn\'t contain any note.\n");
		return 0;
	}
	else if (total_nbr_tracks > GTK_NBRTRACKS_MAXI)
	{
		LOG_printf ("Too many tracks.\n");
		return (-1);
	}

	// Sets the number of tracks
	if (PAT_set_nbr_tracks (Pattern_TYPE_SPL, total_nbr_tracks) != 0)
	{
		LOG_printf (
			"Error: cannot change the number of tracks to %d.\n", total_nbr_tracks
		);
		return (-1);
	}

	TimeRef        tref;
	tref._midi  = mf.get_time_info ();
	Player &			player = Player::use_instance ();
	tref._speed = player.get_speed ();
	// Compute a tempo multiplier depending of:
	// - the desired number of lines per quarter
	// - the current speed (ticks per line)
	// GT reference is 24 ticks per quarter, so if we want 48 ticks/quarter,
	// we have to double the tempo.
	tref._tempo_mul = double (tref._speed * tref._lines_per_quarter) / (6 * 4);

	// Build the patterns
	int            pat_idx = 0;
	for (const auto &mp : pat_list)
	{
		// Converts a MIDI pattern into a virtual GT pattern
		const auto     pat_tmp   = build_one_pattern (mp, tref, tc);
		const auto     nbr_lines = pat_tmp.get_nbr_lines ();

		// Now finds split lines the GT patterns
		std::set <int> line_set;
		for (auto split_ts : mp._split_list)
		{
			const auto     time_pat = conv_time (split_ts, mp, tref);
			const int      line     = time_pat._line; // Rounds down (ignore ticks)
			if (line > 0 && line < nbr_lines)
			{
				line_set.insert (line);
			}
		}
		line_set.insert (nbr_lines); // Terminates with the pattern end

		// Second split to avoid exceeding the maximum GT pattern length
		int            prev_split_line = 0;
		for (auto it = line_set.begin (); it != line_set.end (); ++it)
		{
			while (*it - prev_split_line > GTK_NBRLINES_MAXI)
			{
				prev_split_line += GTK_NBRLINES_MAXI;
				line_set.insert (prev_split_line);
			}
		}

		// Splits the pattern and writes it in the GT document
		int         line_beg = 0;
		for (auto line_end : line_set)
		{
			ret_val = add_pattern_part (pat_tmp, pat_idx, line_beg, line_end);
			if (ret_val != 0)
			{
				return ret_val;
			}
			++ pat_idx;
			line_beg = line_end;
		}
	}

	// Builds a song with all the pattern
	SONG_set_song_length (pat_idx);
	SONG_set_song_repeat (0);
	for (int pos = 0; pos < pat_idx; ++pos)
	{
		SONG_set_pattern_number (pos, pos);
	}

	// Finds and compute a default tempo
	auto           tempo = find_tempo (mf);
	if (tempo <= 0)
	{
		tempo = 120.0;
	}
	tempo *= tref._tempo_mul;
	tempo  = fstb::limit (
		tempo, double (Player::MIN_TEMPO), double (Player::MAX_TEMPO)
	);

	// Sets real-time data
	{
		std::lock_guard <std::mutex>	lock (GTK_mutex);
		player.set_tempo (tempo);
	}



	LOG_printf ("MIDI file loaded.\n\n");

	return 0;
}



// Detects a Standard MIDI File from its first bytes: at least 14 bytes must
// be available, starting with the "MThd" chunk identifier.
// Returns true if the header matches.
bool	MODMID_detect_format (FILE *file_ptr, const void *header_ptr, long header_length, BYTE temp_info [MODS_TEMP_INFO_LEN])
{
	fstb::unused (file_ptr, temp_info);

	if (header_length < 14)
	{
		return false;
	}

	static const char signature [4] = { 'M', 'T', 'h', 'd' };
	const auto     byte_ptr = static_cast <const uint8_t *> (header_ptr);
	for (int pos = 0; pos < 4; ++pos)
	{
		if (byte_ptr [pos] != uint8_t (signature [pos]))
		{
			return false;
		}
	}

	return true;
}



/*\\\ PRIVATE FUNCTIONS \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\*/



// Reads the whole content of the file named filename_0 (UTF-8 path) into
// content.
// Returns 0 on success, -1 if the file cannot be opened or read.
static int	load_file_in_memory (std::vector <uint8_t> &content, const char filename_0 [])
{
	assert (filename_0 != nullptr);

	auto           f_ptr = fstb::fopen_utf8 (filename_0, "rb");
	if (f_ptr == nullptr)
	{
		// Must not go further: the FILE* overload requires a valid stream
		return -1;
	}

	const int      ret_val = load_file_in_memory (content, f_ptr);
	fclose (f_ptr);

	return ret_val;
}



// Reads the whole content of an opened file into content, starting from the
// beginning of the file whatever its current position.
// Returns 0 on success, or -1 on error (seek or read failure, empty file,
// or file size that can't be determined).
static int	load_file_in_memory (std::vector <uint8_t> &content, FILE *f_ptr)
{
	assert (f_ptr != nullptr);

	if (fseek (f_ptr, 0, SEEK_END) != 0)
	{
		return -1;
	}

	// ftell returns -1 on failure: it must not be converted to a size_t.
	// An empty file is an error, as in the original read-based check.
	const long     len = ftell (f_ptr);
	if (len <= 0)
	{
		return -1;
	}
	if (fseek (f_ptr, 0, SEEK_SET) != 0)
	{
		return -1;
	}

	content.resize (size_t (len));
	if (fread (content.data (), size_t (len), 1, f_ptr) != 1)
	{
		return -1;
	}

	return 0;
}



// With format 1 ("multi"), we merge all the MIDI tracks.
// With other formats, the MIDI tracks are handled individually
// Gathers the notes, tempo changes and time signature changes of the MIDI
// tracks into MIDI patterns.
// With format 1 ("multi"), we merge all the MIDI tracks into a single
// pattern. With other formats, each MIDI track gets its own pattern.
// In both cases, the first Note On of each MIDI track (when not at 0) is
// recorded as a possible split point of its pattern.
static PatternList	reorganize_patterns (const MidiFile &mf)
{
	PatternList    pat_list;

	const bool     multi_flag = (mf.get_format () == MidiFile::Format::_multi);
	const int      nbr_mtrk   = mf.get_nbr_tracks ();
	for (int mtrk_idx = 0; mtrk_idx < nbr_mtrk; ++mtrk_idx)
	{
		const auto &   midi_trk = mf.use_track (mtrk_idx);

		// Merges everything into a single pattern, or uses a new one depending
		// on the format
		pat_list.resize (multi_flag ? 1 : mtrk_idx + 1);
		auto &         pattern = pat_list.back ();

		// Notes. ts_min collects the earliest Note On of this MIDI track. It
		// must start from the largest value (not 0) to actually find it, and
		// stays at ts_none when the track has no note.
		constexpr auto ts_none = std::numeric_limits <int32_t>::max ();
		int32_t        ts_min  = ts_none;
		for (int chn_idx = 0; chn_idx < midi::_nbr_chn; ++chn_idx)
		{
			auto &         notes     = pattern._chn_arr [chn_idx];
			const auto     notes_tmp = midi_trk.get_notes (chn_idx);
			for (const auto &mn : notes_tmp)
			{
				notes._note_arr.push_back ({ mn });
				ts_min = std::min (ts_min, mn._timestamp);
			}

			// Once all the tracks are merged, sorts the notes chronologically,
			// as count_tracks() requires. The stable sort keeps simultaneous
			// notes in MIDI track order, so the subtrack allocation is
			// deterministic.
			if (multi_flag && mtrk_idx == nbr_mtrk - 1)
			{
				std::stable_sort (notes._note_arr.begin (), notes._note_arr.end (),
					[] (const MidiNote &lhs, const MidiNote &rhs)
					{
						return lhs._note._timestamp < rhs._note._timestamp;
					}
				);
			}
		}

		// Tempo
		const auto &   tempo_map = midi_trk.use_tempo_map ();
		pattern._tempo_map.insert (tempo_map.begin (), tempo_map.end ());

		// Time signature
		const auto &   time_sig_map = midi_trk.use_time_sig_map ();
		pattern._time_sig_map.insert (time_sig_map.begin (), time_sig_map.end ());

		// Pattern origin
		if (ts_min != ts_none && ts_min > 0)
		{
			pattern._split_list.push_back (ts_min);
		}
	}

	return pat_list;
}



// Runs the subtrack allocation on every channel of every pattern and
// returns, for each channel, the largest number of tracks needed by any
// pattern. Fills the allocation results of all the MidiChannel objects.
static TrackCount	count_tracks (PatternList &pat_list)
{
	TrackCount     max_per_chn {};

	for (auto &pattern : pat_list)
	{
		int            chn_idx = 0;
		for (auto &channel : pattern._chn_arr)
		{
			const int      needed = count_tracks (channel);
			if (needed > max_per_chn [chn_idx])
			{
				max_per_chn [chn_idx] = needed;
			}
			++ chn_idx;
		}
	}

	return max_per_chn;
}



// Assumes notes are sorted by chronological Note On timestamps
// Fills chn._nbr_tracks, chn._ts_end and all chn._note_arr []._trk_idx
// Allocates the subtracks of a channel: each note goes on the first track
// that is free at its Note On time, a new track being added when none is.
// Assumes notes are sorted by chronological Note On timestamps.
// Fills chn._nbr_tracks, chn._ts_end and all chn._note_arr []._trk_idx.
// Returns the number of tracks needed by the channel.
static int	count_tracks (MidiChannel &chn)
{
	// For each track, end timestamp of the note it plays, or -1 when free
	std::vector <int32_t> busy_until;

	int32_t        latest_end = 0;

	for (auto &note : chn._note_arr)
	{
		const auto     ts_beg = note._note._timestamp;

		// Frees the tracks whose note is over when this one starts.
		// <= and not < because simultaneous Note Off -> Note On need only a
		// single track.
		for (auto &ts : busy_until)
		{
			if (ts >= 0 && ts <= ts_beg)
			{
				ts = -1;
			}
		}

		// Zero-length notes count as 1 tick long, so instantaneous hits
		// still occupy a track
		const auto     ts_end = ts_beg + std::max <int32_t> (note._note._duration, 1);
		latest_end = std::max (latest_end, ts_end);

		// Looks for the first free track, otherwise appends a new one
		const int      nbr_trk  = int (busy_until.size ());
		int            free_idx = 0;
		while (free_idx < nbr_trk && busy_until [free_idx] >= 0)
		{
			++ free_idx;
		}
		if (free_idx < nbr_trk)
		{
			busy_until [free_idx] = ts_end;
		}
		else
		{
			busy_until.push_back (ts_end);
		}
		note._trk_idx = free_idx;
	}

	chn._nbr_tracks = int (busy_until.size ());
	chn._ts_end     = latest_end;

	return chn._nbr_tracks;
}



// Returns 0 if not found
// Returns the first entry of the tempo map of the first MIDI track that has
// a non-empty tempo map, or 0 if no track has any tempo information.
static double	find_tempo (const MidiFile &mf)
{
	const auto     nbr_tracks = mf.get_nbr_tracks ();
	for (int trk_idx = 0; trk_idx < nbr_tracks; ++trk_idx)
	{
		const auto &   track     = mf.use_track (trk_idx);
		const auto &   tempo_map = track.use_tempo_map ();
		if (tempo_map.empty ())
		{
			continue;
		}
		const double   first_tempo = tempo_map.begin ()->second;
		return first_tempo;
	}

	return 0;
}



// Converts a MIDI pattern into a temporary GT pattern.
// Tracks are laid out channel by channel: channel c uses tc [c] consecutive
// tracks, right after the tracks of channel c - 1. Each note goes on the
// subtrack chosen during the allocation (MidiNote::_trk_idx):
// - Note On: pitch, instrument (MIDI channel + 1), volume from the velocity,
//   plus a Note Delay effect (0x09) when it doesn't start on a line boundary
// - Note Off: Note Cut effect (0x0A) at the end line/tick.
// The pattern length is the end of the last note rounded up to a whole line.
// It is always large enough to contain every Note On, and is at least 1 line.
// NOTE(review): subtracks are allocated from MIDI timestamps, so after the
// quantization to GT lines, two consecutive notes of a subtrack may still
// share a cell; the later write then overwrites the fields of the earlier.
static PatTmp	build_one_pattern (const MidiPattern &mp, const TimeRef &tref, const TrackCount &tc)
{
	// Pattern length. Rounded up so Note Offs within the last line are kept.
	const auto     ts_end     = find_end_timestamp (mp);
	const auto     t_pat_end  = conv_time (ts_end, mp, tref);
	int            nbr_lines  = t_pat_end._line + ((t_pat_end._tick > 0) ? 1 : 0);

	// A very short last note may be quantized exactly on the pattern end.
	// Makes sure its Note On line still belongs to the pattern (conv_time()
	// is monotonic, so checking the latest Note On is enough).
	int32_t        ts_last_on = -1;
	for (const auto &chn : mp._chn_arr)
	{
		for (const auto &note : chn._note_arr)
		{
			ts_last_on = std::max (ts_last_on, note._note._timestamp);
		}
	}
	if (ts_last_on >= 0)
	{
		nbr_lines = std::max (nbr_lines, conv_time (ts_last_on, mp, tref)._line + 1);
	}

	// A pattern without any note still gets a single empty line
	nbr_lines = std::max (nbr_lines, 1);

	const auto     nbr_tracks = std::accumulate (tc.begin (), tc.end (), 0);

	PatTmp         pat (nbr_tracks, nbr_lines);

	int            trk_base = 0;
	for (int chn_idx = 0; chn_idx < midi::_nbr_chn; ++chn_idx)
	{
		const auto &   chn = mp._chn_arr [chn_idx];

		// Number of subtracks for this channel
		const auto     nbr_c_trk = tc [chn_idx];

		for (const auto &note : chn._note_arr)
		{
			const auto    n_ts_beg = note._note._timestamp;
			const auto    n_ts_end = note._note._timestamp + note._note._duration;
			const auto    t_beg    = conv_time (n_ts_beg, mp, tref);
			auto          t_end    = conv_time (n_ts_end, mp, tref);

			// Checks the Note On/Off positions. In case of very short notes,
			// both could share the same line. Therefore if the Note On is delayed
			// relatively to the line beginning, we move the Note Off to the
			// next line. We always keep the Note On at its position because
			// it is generally more important, regarding the overall rhythm.
			// BUT: t_end could now be located out of the pattern if it was
			// previously on the last line, so we'll have to check its position
			// before inserting the command.
			if (t_end._line == t_beg._line && t_beg._tick > 0)
			{
				++ t_end._line;
				t_end._tick = 0;
			}

			// Converts between MIDI note and GT note. Actually both should be
			// the same value.
			constexpr auto pitch_ofs = Sample_REF_A440 - midi::_note_a440;

			const auto     pitch = fstb::limit <int> (
				note._note._note + pitch_ofs, 1, GTK_NBRNOTES_MAXI - 1
			);
			using          VeloMap = MidiVeloMapper <MODMID_velo_dr>;
			const auto     vol     = fstb::round_int (
				VeloMap::conv_velo_to_gain (note._note._velo) * 0x40
			);

			// Finds a track where the note will play
			const auto     trk_idx = note._trk_idx;
			assert (trk_idx >= 0);
			assert (trk_idx < nbr_c_trk);
			auto &         track = pat._track_arr [trk_base + trk_idx];

			// Note On. Always inside the pattern, thanks to nbr_lines.
			assert (t_beg._line < nbr_lines);
			auto &         gt_on = track [t_beg._line];
			gt_on.note    = UBYTE (pitch);
			gt_on.instr   = UBYTE (chn_idx + 1);
			gt_on.volume  = UBYTE (0x10 + vol);
			gt_on.fxnum   = (t_beg._tick > 0) ? 0x09 : 0; // Note delay
			gt_on.fxparam = UBYTE (t_beg._tick);

			// Note Off
			// Tries to fix a Note Off slightly out of the score by moving it to
			// the last tick of the last line. Only done when the Note On is on
			// an earlier line: otherwise the Note Cut would land in the Note On
			// cell and overwrite its Note Delay effect.
			if (t_end._line >= nbr_lines && t_beg._line < nbr_lines - 1)
			{
				t_end._line = nbr_lines - 1;
				t_end._tick = tref._speed - 1;
			}
			if (t_end._line < nbr_lines)
			{
				auto &         gt_off = track [t_end._line];
				gt_off.fxnum   = 0x0A; // Note cut
				gt_off.fxparam = UBYTE (t_end._tick);
			}
		}

		trk_base += nbr_c_trk;
	}

	return pat;
}



// Returns the latest note end timestamp over all the channels of the
// pattern (0 if there is no note). The channel end timestamps must have
// been computed beforehand.
static int32_t	find_end_timestamp (const MidiPattern &mp)
{
	int32_t        latest = 0;

	for (auto it = mp._chn_arr.begin (); it != mp._chn_arr.end (); ++it)
	{
		assert (it->_ts_end >= 0);
		if (it->_ts_end > latest)
		{
			latest = it->_ts_end;
		}
	}

	return latest;
}



// In SMPTE timing mode, we need an explicit tempo to map absolute time to
// musical beats.
// Converts a MIDI timestamp (MIDI file ticks) into a GT position: line and
// tick within the line, relative to the pattern beginning.
// With musical timing, MIDI ticks per quarter are rescaled into GT ticks per
// quarter, rounded to the nearest GT tick.
// In SMPTE timing mode, we need an explicit tempo to map absolute time to
// musical beats. The pattern tempo map is used, with 120 quarters per
// minute until the first tempo change.
static GtPatTime	conv_time (int32_t timestamp, const MidiPattern &mp, const TimeRef &tref)
{
	// GT ticks per quarter
	const auto     gt_t_p_q = tref._lines_per_quarter * tref._speed;

	// Musical time reference
	auto           conv_music = [&] () {
		// Just performs a tick conversion (MIDI file -> GT).
		// The product is computed on 64 bits because it can exceed the
		// 32-bit range with long songs and high resolutions or speeds.
		const int64_t  num   = int64_t (timestamp) * gt_t_p_q;
		const int64_t  den   = tref._midi._nbr_ticks;
		const int64_t  r_cst = den >> 1;

		return int ((num + r_cst) / den);
	};

	// SMPTE time reference
	auto           conv_smpte = [&] () {
		// Multiplier to convert a timestamp into a time in seconds:
		// 1 / (frames per second * ticks per frame).
		// NOTE(review): assumes fps = _fps_num / _fps_den and that _nbr_ticks
		// holds the ticks per frame in SMPTE mode -- confirm with MidiFile.
		const auto     ts_mul =
			  double (tref._midi._smpte->_fps_den)
			/ (  double (tref._midi._smpte->_fps_num)
			   * double (tref._midi._nbr_ticks));

		// We will use the tempo map to convert the timestamps into quarters
		int32_t        cur_ts    = 0;
		double         cur_quart = 0;
		double         cur_tempo = 120; // Quarter per minute
		auto           next_change_it = mp._tempo_map.begin ();

		// Helper function to convert a timestamp into a quarter
		auto           ts_to_quarter = [&] (int32_t ts_target) {
			const auto     dif_ts = ts_target - cur_ts;
			const auto     dif_s  = double (dif_ts) * ts_mul;
			return dif_s * cur_tempo / 60.0;
		};

		// Runs all the tempo changes until we reach the segment containing our
		// desired timestamp
		while (   next_change_it != mp._tempo_map.end ()
		       && next_change_it->first < timestamp)
		{
			const auto     dif_quart = ts_to_quarter (next_change_it->first);
			cur_quart += dif_quart;
			cur_ts     = next_change_it->first;
			cur_tempo  = next_change_it->second;
			++ next_change_it;
		}

		// Finally converts the remaining ticks in the last segment
		const auto     dif_quart = ts_to_quarter (timestamp);
		cur_quart += dif_quart;

		// Quarters to lines/ticks
		return fstb::round_int (cur_quart * gt_t_p_q);
	};

	const auto     gt_tick = (tref._midi._smpte) ? conv_smpte () : conv_music ();
	const auto     line    = gt_tick / tref._speed;
	const auto     tick    = gt_tick - tref._speed * line;

	return GtPatTime { line, tick };
}



// Copies lines [line_beg ; line_end[ of all the tracks of a temporary
// pattern into GT pattern pat_idx, after setting its height to
// line_end - line_beg lines.
// Returns 0 on success, -1 if the pattern height couldn't be set.
static int	add_pattern_part (const PatTmp &pat_tmp, int pat_idx, int line_beg, int line_end)
{
	assert (pat_idx >= 0);
	assert (line_beg >= 0);
	assert (line_beg < line_end);
	assert (line_end <= pat_tmp.get_nbr_lines ());

	const auto     height = line_end - line_beg;
	if (PAT_set_pattern_height (pat_idx, height))
	{
		LOG_printf ("MODMID_load_song: Error: couldn't set pattern height # %d to %d lines.\n",
		            pat_idx, height);
		return -1;
	}

	const auto     nbr_tracks = int (pat_tmp._track_arr.size ());
	assert (nbr_tracks > 0);

	for (int src_line = line_beg; src_line < line_end; ++src_line)
	{
		const int      dst_line = src_line - line_beg;
		for (int trk_idx = 0; trk_idx < nbr_tracks; ++trk_idx)
		{
			auto * const   dst_ptr = static_cast <MODS_GT2_SPL_NOTE *> (
				PAT_get_note_adr_pat (Pattern_TYPE_SPL, pat_idx, dst_line, trk_idx)
			);
			*dst_ptr = pat_tmp._track_arr [trk_idx] [src_line];
		}
	}

	return 0;
}



/*\\\ EOF \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\*/
