# Copyright (c) 2020 -- Élie Michel <elie.michel@exppad.com>
#
# ##### BEGIN GPL LICENSE BLOCK #####
#
#  This program is free software; you can redistribute it and/or
#  modify it under the terms of the GNU General Public License
#  as published by the Free Software Foundation; either version 2
#  of the License, or (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software Foundation,
#  Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# ##### END GPL LICENSE BLOCK #####

# Add-on metadata read by Blender to list this add-on in Preferences > Add-ons.
bl_info = {
    "name": "LilyCaptureMerger",
    "author": "Élie Michel <elie.michel@exppad.com>",
    "version": (1, 0, 4),
    "description": "Merge collections of objects together based on the textures they share",
    "blender": (2, 90, 0),
    "location": "View3D > Object > LilyCaptureMerger",
    "wiki_url": "",
    "tracker_url": "",
    "warning": "",
    "category": "Object"
}

import bpy
from mathutils import Matrix
import os
from hashlib import md5

######################################################################

# Index of the input socket of a Principled BSDF node whose upstream textures are
# favoured (+100 priority) by list_image_nodes(), i.e. the base color input.
BASE_COLOR_INDEX = 0

def get_material_output(material):
    """Return the Material Output node of ``material``, or None if there is none.

    When the node tree contains several Material Output nodes, the one flagged
    ``is_active_output`` is preferred; otherwise the first one found is returned.
    Returns None for a missing material or a material without a node tree.
    """
    if material is None or material.node_tree is None:
        return None
    outputs = [node for node in material.node_tree.nodes
               if node.type == "OUTPUT_MATERIAL"]
    for node in outputs:
        # Only one output is actually used for rendering: the active one.
        if getattr(node, "is_active_output", False):
            return node
    return outputs[0] if outputs else None

def list_image_nodes(node, weight=0):
    """Collect the image texture nodes feeding ``node``, each with a priority.

    The graph is walked upstream through input links. Every hop away from the
    starting node lowers the priority by 1. Anything linked into input number
    BASE_COLOR_INDEX of a Principled BSDF node gets a +100 bonus, so it ranks
    above textures plugged elsewhere.

    Returns a list of ``(texture_node, priority)`` pairs. A texture reachable
    through several paths appears once per path.
    """
    if node.type == 'TEX_IMAGE':
        return [(node, weight)]
    found = []
    for index, socket in enumerate(node.inputs):
        bonus = 100 if node.type == 'BSDF_PRINCIPLED' and index == BASE_COLOR_INDEX else 0
        upstream_weight = weight + bonus - 1
        for link in socket.links:
            found.extend(list_image_nodes(link.from_node, weight=upstream_weight))
    return found

def get_image_node(obj):
    """Return the highest-priority image texture node of ``obj``'s first material.

    Priorities come from list_image_nodes(), starting at the material output.
    Returns None when any of the following holds:
    - the object has no material slot;
    - the first slot is empty;
    - the material has no node tree or no Material Output node;
    - no image texture node feeds the output.
    """
    if len(obj.material_slots) == 0:
        return None
    material = obj.material_slots[0].material
    if material is None or material.node_tree is None:
        return None
    material_output = get_material_output(material)
    if material_output is None:
        return None
    image_nodes = list_image_nodes(material_output)
    if not image_nodes:
        return None
    # max() keeps the first of equally weighted candidates, exactly like the
    # stable descending sort followed by taking the first element.
    best_node, _ = max(image_nodes, key=lambda pair: pair[1])
    return best_node

def get_image_hash(img):
    """Return the hex MD5 digest of an image's source data, or '' if unavailable.

    Packed images are hashed from their embedded bytes. Otherwise the file at
    ``img.filepath_from_user()`` is hashed, read in chunks so large textures are
    never loaded into memory at once. If ``img`` is None, or neither a packed file
    nor a readable file exists, a warning is printed and '' is returned.
    """
    if img is None:
        print("Warning: image texture node has no image")
        return ''
    if img.packed_file is not None:
        return md5(img.packed_file.data).hexdigest()
    path = img.filepath_from_user()
    if os.path.isfile(path):
        digest = md5()
        # `with` guarantees the file handle is closed (the original leaked it).
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                digest.update(chunk)
        return digest.hexdigest()
    print("Warning: could not get data for image {}".format(img.name))
    return ''

def get_source_dest(context):
    """Pick the (source, destination) object pair from the current selection.

    The active object is the source, i.e. the reference the other collection
    gets aligned to. The destination is the first selected object that differs
    from the active one, or None when there is no such object.
    """
    reference = context.active_object
    candidates = (candidate for candidate in context.selected_objects
                  if candidate != reference)
    return reference, next(candidates, None)

def is_valid(obj):
    """Return True iff obj is a mesh whose first material slot holds a node-based material.

    This guarantees that get_image_node() can walk the material's node tree.
    Whether that tree actually contains an image texture is checked later.
    """
    if obj.type != 'MESH' or len(obj.material_slots) == 0:
        return False
    material = obj.material_slots[0].material
    # Empty slots (material None) and materials without a node tree would make
    # the node-tree traversal fail.
    return material is not None and material.node_tree is not None

def make_lut(collec):
    """Map image content hash -> object for the textured meshes of ``collec``.

    Objects of child collections are included (``all_objects``). The following
    objects are skipped:
    - objects rejected by is_valid();
    - objects with no image texture, or whose texture node has no image;
    - objects whose image data cannot be read. Such images all hash to '' and
      would otherwise spuriously match each other across collections.

    When several objects share the same image content, the last one wins.
    """
    lut = {}
    for obj in collec.all_objects:
        if not is_valid(obj):
            continue
        node = get_image_node(obj)
        if node is None or node.image is None:
            print("Warning: no image texture found on object {}".format(obj.name))
            continue
        image_hash = get_image_hash(node.image)
        if not image_hash:
            continue
        lut[image_hash] = obj
    return lut

def vote_for(matrix, voter, votes, tolerance=1e-5):
    """Add ``voter``'s vote for ``matrix`` to ``votes`` (modified in place).

    ``votes`` maps a representative matrix (frozen, so it can be a dict key) to
    the list of voters whose matrix is closer than ``tolerance`` to it, as
    measured by matrix_distance(). The vote goes to the first existing bucket
    that is close enough; otherwise ``matrix`` opens a new bucket.

    The threshold used to be ``1e5``. That put every vote into the first bucket,
    so the vote never discriminated between transforms. If float noise splits
    transforms that should match, pass a looser ``tolerance``.
    """
    # Linear scan over the buckets: "close enough" cannot be expressed as a hash
    # lookup, so the dict is only used as an ordered bucket container.
    for representative, voters in votes.items():
        if matrix_distance(representative, matrix) < tolerance:
            voters.append(voter)
            return
    votes[matrix] = [voter]

def get_argmax_vote(votes):
    """Return the ``(matrix, voters)`` bucket of ``votes`` with the most voters.

    Ties go to the bucket encountered first in iteration order. Returns
    ``(None, None)`` when ``votes`` is empty.
    """
    if not votes:
        return None, None
    # max() keeps the first maximal item, matching the strict ">" of a manual scan.
    return max(votes.items(), key=lambda item: len(item[1]))

def matrix_distance(m1, m2):
    """Measure how far transform ``m1`` is from transform ``m2``.

    Builds the relative transform ``m1 @ m2^-1`` and subtracts the identity
    (``Matrix()``). It then sums the Euclidean lengths of the vectors yielded by
    iterating that difference (its rows, in mathutils). The result is 0 when the
    two transforms coincide.
    """
    relative = m1 @ m2.inverted()
    deviation = relative - Matrix()
    return sum(row.length for row in deviation)

def apply_matrix(collec_name, matrix, duplicates):
    """Transform the objects of a collection by ``matrix`` and delete duplicates.

    The collection is looked up by name in ``bpy.data.collections``. Every
    object it contains, including those in child collections, is handled in turn:
    - objects listed in ``duplicates`` are removed from ``bpy.data.objects``;
    - every other object gets ``matrix`` pre-multiplied onto its world matrix.
    """
    collection = bpy.data.collections[collec_name]
    # Iterate over a snapshot: objects are removed from the collection as we go.
    for candidate in list(collection.all_objects):
        if candidate not in duplicates:
            candidate.matrix_world = matrix @ candidate.matrix_world
            continue
        bpy.data.objects.remove(candidate, do_unlink=True)

def merge_captures(source_obj, dest_obj, remove_duplicates):
    """Align the collection of ``dest_obj`` onto the collection of ``source_obj``.

    Objects of the two collections are paired when their image textures have
    identical content (see make_lut). Each pair votes for the transform that maps
    the destination object onto its source counterpart. The most voted transform
    is applied to every object of the destination collection. If
    ``remove_duplicates`` is true, the destination objects that voted for that
    transform are deleted.

    Returns True on success. Returns False, without modifying anything, when both
    objects are in the same collection or when no texture is shared.
    """
    source_collec = source_obj.users_collection[0]
    dest_collec = dest_obj.users_collection[0]
    if source_collec == dest_collec:
        # Every object would match itself with the identity transform, and
        # remove_duplicates would then delete every textured object in it.
        print("Source and destination objects are in the same collection ({}), nothing to merge".format(source_collec.name))
        return False

    print("Source collection: {}".format(source_collec.name))
    source_lut = make_lut(source_collec)

    # apply_matrix() looks the collection up again by name.
    dest_collec_name = dest_collec.name
    print("Destination collection: {}".format(dest_collec_name))
    dest_lut = make_lut(dest_collec)

    votes = {}
    for image_hash, dest_match in dest_lut.items():
        source_match = source_lut.get(image_hash)
        if source_match is None:
            continue
        transform = source_match.matrix_world @ dest_match.matrix_world.inverted()
        # Frozen matrices are hashable, which lets vote_for() use them as dict keys.
        transform.freeze()
        vote_for(transform, dest_match, votes)

    if not votes:
        print("Could not find any matching texture")
        return False

    matrix, voters = get_argmax_vote(votes)
    print("Applied transform:")
    print(matrix)

    duplicates = voters if remove_duplicates else []
    apply_matrix(dest_collec_name, matrix, duplicates)
    if duplicates:
        print("Removed {} redundant objects".format(len(duplicates)))
    return True


######################################################################


class OBJECT_OT_LilyCaptureMerger(bpy.types.Operator):
    """Merge collections of objects together based on the textures they share"""
    # NOTE: the class docstring above doubles as the operator tooltip in Blender,
    # so it is kept verbatim.
    bl_idname = "object.lily_capture_merger"
    bl_label = "Lily Capture Merger"
    bl_options = {'REGISTER', 'UNDO'}

    remove_duplicates: bpy.props.BoolProperty(
        name="Remove duplicates",
        default=True,
        description="remove objects that are detected as matching in the two collections.",
    )

    @classmethod
    def poll(cls, context):
        # Needs both an active object (the reference) and another selected
        # object; without an active object execute() would crash on None.
        source_obj, dest_obj = get_source_dest(context)
        return source_obj is not None and dest_obj is not None

    def execute(self, context):
        source_obj, dest_obj = get_source_dest(context)
        if not merge_captures(source_obj, dest_obj, self.remove_duplicates):
            # merge_captures() prints the reason. Surface the failure in the UI and
            # do not push a no-op undo step.
            self.report({'WARNING'}, "Could not merge the selected collections, see the system console for details")
            return {'CANCELLED'}
        return {'FINISHED'}

def menu_func_import(self, context):
    """Draw callback adding the operator's entry to the menu it is appended to.

    Despite its name, register() appends it to the 3D Viewport Object menu,
    not to an import menu.
    """
    layout = self.layout
    layout.operator(OBJECT_OT_LilyCaptureMerger.bl_idname)

def register():
    """Register the operator and add its entry to the 3D Viewport Object menu."""
    operator_cls = OBJECT_OT_LilyCaptureMerger
    bpy.utils.register_class(operator_cls)
    object_menu = bpy.types.VIEW3D_MT_object
    object_menu.append(menu_func_import)

def unregister():
    """Undo register(), in reverse order.

    The menu entry is removed first, so the menu never draws a button for an
    operator class that is no longer registered.
    """
    bpy.types.VIEW3D_MT_object.remove(menu_func_import)
    bpy.utils.unregister_class(OBJECT_OT_LilyCaptureMerger)

# Register the add-on when this file is executed directly as a script
# (e.g. run from Blender's Text Editor) rather than installed as an add-on.
if __name__ == "__main__":
    register()
