import logging
from collections import defaultdict
from typing import TypeVar, Any

from netex import (
    Route,
    ServiceJourneyPattern,
    Line,
    PassengerStopAssignment,
    ScheduledStopPoint,
    EntityStructure,
    DayTypeAssignment,
    DayType,
    UicOperatingPeriod,
    PublicationDelivery,
    TypeOfFrameRef,
    ResponsibilitySet,
    StopPointInJourneyPattern,
)
from netexio.attributes import update_attr
from netexio.database import Database
from netexio.dbaccess import recursive_resolve, load_referencing_inwards, load_generator, setup_database
from netexio.pickleserializer import MyPickleSerializer
from netexio.xml import export_publication_delivery_xml
from transformers.epip import export_epip_network_offer
from transformers.references import split_path
from utils.profiles import EPIP_CLASSES
from utils.aux_logging import log_all, prepare_logger
from utils.utils import get_interesting_classes

import multiprocessing
from concurrent.futures import ProcessPoolExecutor

Tid = TypeVar("Tid", bound=EntityStructure)


def generate_epip_line_split(source_database_file: str, target_database_file: str, object_type: str, object_id: str) -> None:
    """Export one object and everything resolved from it as its own EPIP line offer.

    The object ``object_id`` of class ``object_type`` is read from the source
    database.  Everything ``recursive_resolve`` collects for it is written into
    a fresh per-object database, references to classes outside ``EPIP_CLASSES``
    are blanked out, and the result is exported as a publication delivery.

    Output names are derived from ``target_database_file``: the text before the
    last ``.lmdb`` (or the whole name when it has no ``.lmdb``) is suffixed with
    ``_<object id with ':' replaced by '_'>`` and then ``.lmdb`` / ``.xml.gz``.

    Args:
        source_database_file: lmdb database to read from (opened read-only).
        target_database_file: template name for the per-object outputs.
        object_type: name of the class of the object to split by.
        object_id: id of the object to split by.

    Raises:
        ValueError: if the source database returns no object for ``object_id``.
    """
    filter_set = {Line, ServiceJourneyPattern, DayType, ScheduledStopPoint}
    filter_set_assignment = {DayType: {DayTypeAssignment}, ScheduledStopPoint: {PassengerStopAssignment}}
    with Database(source_database_file, serializer=MyPickleSerializer(compression=True), readonly=True, multithreaded=True) as db_read:
        clazz = db_read.get_class_by_name(object_type)
        split_by = db_read.get_single(clazz, object_id)
        # NOTE(review): the result of get_single is tested for falsiness further down, so it
        # presumably can come back empty; fail with a clear message instead of an AttributeError.
        if split_by is None:
            raise ValueError(f"No {object_type} with id {object_id!r} found in {source_database_file}")

        # rpartition never raises: a name without '.lmdb', or with '.lmdb' occurring more than
        # once (for example in a directory name), used to break the two-value unpacking of split().
        base, marker, _ = target_database_file.rpartition('.lmdb')
        if not marker:
            base = target_database_file
        safe_id = split_by.id.replace(':', '_')
        new_target_database_file = base + '_' + safe_id + '.lmdb'
        new_xml_file = base + '_' + safe_id + '.xml.gz'

        with Database(new_target_database_file, serializer=MyPickleSerializer(compression=True), readonly=False) as db_write:
            setup_database(db_write, classes=get_interesting_classes(EPIP_CLASSES), clean=True)

            # TODO: This is memory intensive, ideally we only keep what we have resolved and yield the objects to write them into the database
            resolved: list[Any] = []
            recursive_resolve(db_read, split_by, resolved, split_by.id, filter_set, filter_set_assignment=filter_set_assignment)

            for obj in resolved:
                db_write.insert_one_object(obj)

            # Wait for the inserts above before the reference tables are read back.
            db_write.block_until_done()

            # (parent id, parent version, parent class) -> attribute paths to blank out
            result: dict[tuple[str, str, Any], list[str]] = defaultdict(list)

            # TODO: For now EPIP
            # TODO: It seems that the ValueSet for some reason removes BISON:TypeOfResponsibilityRole:financing
            removable_classes = db_write.tables() - EPIP_CLASSES
            for removable_class in removable_classes:
                for parent_id, parent_version, parent_class, path in load_referencing_inwards(db_write, removable_class):
                    parent_klass: type[Any] = db_write.get_class_by_name(parent_class)  # TODO: refactor at load_referencing_*
                    if parent_klass in EPIP_CLASSES:
                        # Aggregate all parent_ids, so we prevent concurrency issues, and the cost of deserialisation and serialisation
                        key = (parent_id, parent_version, parent_klass)
                        result[key].append(path)

            # TODO: Once removed the export should have less elements in the GeneralFrame, and only the relevant extra elements
            for key, paths in result.items():
                parent_id, parent_version, parent_klass = key
                obj = db_write.get_single(parent_klass, parent_id, parent_version)
                if obj:
                    # Set every collected path on the parent to None, then store the parent again.
                    for path in paths:
                        split = split_path(path)
                        update_attr(obj, split, None)

                    print("REMOVED", obj.id, paths)

                    db_write.insert_one_object(obj, delete_embedding=True)

                else:
                    print("MISSING", parent_klass, parent_id, parent_version, paths)

            # db_write.block_until_done()
            # rs: ResponsibilitySet = db_write.get_single(ResponsibilitySet, "RET:ResponsibilitySet:Partition_ALL")
            # rs.roles.responsibility_role_assignment[0].responsible_area_ref.version = rs.version
            # db_write.insert_one_object(rs, delete_embedding=True)

            # Wait for the re-inserted parents before exporting.
            db_write.block_until_done()

            publication_delivery: PublicationDelivery = export_epip_network_offer(
                db_write, composite_frame_id=split_by.id, type_of_frame_ref=TypeOfFrameRef(ref='epip:EU_PI_LINE_OFFER', version_ref='1.0')
            )
            export_publication_delivery_xml(publication_delivery, new_xml_file)
            print(new_xml_file)


def _process_wrapper(args):
    """Run ``generate_epip_line_split`` for one packed argument tuple.

    ``args`` must unpack into exactly four values, in this order: source
    database file, target database file, object type and object id.
    """
    source_path, target_path, type_name, identifier = args
    generate_epip_line_split(
        source_path,
        target_path,
        type_name,
        identifier,
    )


def main(source_database_file: str, target_database_file: str, object_type: str) -> None:
    """Split the source database into one EPIP export per object of ``object_type``.

    The ids of all objects of the given class are collected first.  Then
    ``generate_epip_line_split`` is run for each id in a pool of worker
    processes (all CPUs but one, with a minimum of one worker).  A failing task
    is reported on stdout with the id it was processing and does not stop the
    remaining tasks.

    Args:
        source_database_file: lmdb database to read from.
        target_database_file: template name passed on to every split task.
        object_type: name of the class whose objects are split by.

    Raises:
        ValueError: if an object of ``object_type`` has no id.
    """
    object_ids: list[str] = []

    with Database(source_database_file, serializer=MyPickleSerializer(compression=True), readonly=True) as db_read:
        db_read.stats()

        for split_by in load_generator(db_read, db_read.get_class_by_name(object_type)):
            # An explicit check rather than ``assert``: asserts vanish under ``python -O`` and
            # a None id would then only fail later, inside a worker process.
            if split_by.id is None:
                raise ValueError(f"Found a {object_type} without an id in {source_database_file}")
            object_ids.append(split_by.id)

    cpu_count = multiprocessing.cpu_count() - 1 or 1

    with ProcessPoolExecutor(max_workers=cpu_count) as executor:
        # Keep each future together with its object id so a failure can be attributed.
        futures = [
            (object_id, executor.submit(generate_epip_line_split, source_database_file, target_database_file, object_type, object_id))
            for object_id in object_ids
        ]
        for object_id, future in futures:
            try:
                future.result()
            except Exception as e:
                # Best effort: report the failed object and carry on with the others.
                print(f"Fout bij taak {object_id}: {e}")


if __name__ == "__main__":
    import argparse
    import traceback

    # Command-line entry point: parse the arguments, set up logging, run main().
    cli = argparse.ArgumentParser(description="Filter the input by an object")

    # Positional arguments, registered in the order they must be given.
    positional_arguments = (
        ("source", "lmdb file to use as input of the transformation."),
        ("object_type", "The NeTEx object type to filter, for example ServiceJourney"),
        ("target", "lmdb file to overwrite and store contents of the transformation."),
    )
    for argument_name, argument_help in positional_arguments:
        cli.add_argument(argument_name, type=str, help=argument_help)

    cli.add_argument("--log_file", type=str, required=False, help="the logfile")

    options = cli.parse_args()
    mylogger = prepare_logger(logging.INFO, options.log_file)
    try:
        main(options.source, options.target, options.object_type)
    except Exception as e:
        # Log the message together with the traceback, then let the error propagate.
        log_all(logging.ERROR, f"{e} {traceback.format_exc()}")
        raise e
