# Copyright (C) 2021-2024 Université Gustave Eiffel.
# Copyright (C) 2025-2026 Université Gustave Eiffel, INRIA.
# This file is part of the EasyFEA project.
# EasyFEA is distributed under the terms of the GNU General Public License v3, see LICENSE.txt and CREDITS.md for more information.

"""
Homog1
======

Conduct homogenization using an example outlined in `Computational Homogenization of Heterogeneous Materials with Finite Elements <https://doi.org/10.1007/978-3-030-18383-7>`_.

Section 4.7 with corrected values on page 89 (Erratum).
"""
# sphinx_gallery_thumbnail_number = -1

import matplotlib.pyplot as plt
import numpy as np

from EasyFEA import Display, Models, ElemType, Simulations
from EasyFEA.Geoms import Points, Circle
from EasyFEA.FEM import LagrangeCondition, FeArray
from typing import Optional


def Compute_ukl(
    simu: Simulations.Elastic,
    nodes_kubc: np.ndarray,
    Ekl: np.ndarray,
    paired_nodes: Optional[np.ndarray] = None,
    pltSol=False,
    useMean0=False,
):
    """Solve the elastic problem driven by the macroscopic strain ``Ekl``.

    The boundary conditions of ``simu`` are re-initialized with
    ``simu.Bc_Init()``, the affine displacement ``u = Ekl . (x, y)`` is
    prescribed on ``nodes_kubc``, the optional Lagrange conditions are added,
    and the problem is solved. The iteration is stored with
    ``simu.Save_Iter()`` before returning.

    Parameters
    ----------
    simu : Simulations.Elastic
        Simulation that is configured and solved in place.
    nodes_kubc : np.ndarray
        Nodes on which the displacement ``Ekl . (x, y)`` is imposed.
    Ekl : np.ndarray
        2x2 macroscopic strain applied to the ``(x, y)`` coordinates.
    paired_nodes : Optional[np.ndarray], optional
        Iterable of node pairs ``(n0, n1)``. When given, one Lagrange
        condition per pair and per direction enforces
        ``u(n0) - u(n1) = Ekl . (x(n0) - x(n1))``. By default None.
    pltSol : bool, optional
        If True, plots ux, uy, Sxx, Syy and Sxy on the deformed mesh,
        by default False.
    useMean0 : bool, optional
        If True, adds one Lagrange condition per direction that sets the
        average nodal displacement to zero, by default False.

    Returns
    -------
    The value returned by ``simu.Solve()``.
    """
    simu.Bc_Init()
    mesh = simu.mesh
    coord = mesh.coord

    directions = ["x", "y"]

    def affine_component(index):
        # Build the callback returning component `index` of Ekl . (x, y).
        def displacement(x, y, z):
            return Ekl.dot([x, y])[index]

        return displacement

    simu.add_dirichlet(
        nodes_kubc, [affine_component(0), affine_component(1)], directions
    )

    if paired_nodes is not None:
        for node_a, node_b in paired_nodes:
            pair = np.array([node_a, node_b])

            # displacement jump imposed between the two paired nodes
            jump = Ekl @ (coord[node_a] - coord[node_b])[:2]

            for direction, value in zip(directions, jump):
                pair_dofs = simu.Bc_dofs_nodes(pair, [direction])
                # coefficients [1, -1]: u(node_a) - u(node_b) = value
                simu._Bc_Add_Lagrange(
                    LagrangeCondition(
                        "elastic", pair, pair_dofs, [direction], [value], [1, -1]
                    )
                )

    if useMean0:
        all_nodes = mesh.nodes
        # uniform weights: sum d_i / Nn = 0
        weights = np.ones(mesh.Nn) / mesh.Nn
        for direction in directions:
            all_dofs = simu.Bc_dofs_nodes(all_nodes, [direction])
            simu._Bc_Add_Lagrange(
                LagrangeCondition(
                    "elastic", all_nodes, all_dofs, [direction], [0], [weights]
                )
            )

    ukl = simu.Solve()

    simu.Save_Iter()

    if pltSol:
        # displacement components first, then stress components
        for result in ("ux", "uy", "Sxx", "Syy", "Sxy"):
            Display.Plot_Result(simu, result, deformFactor=0.3)

    return ukl


if __name__ == "__main__":
    # Script entry point: mesh a unit square cell with a centered circular
    # inclusion, solve three elastic problems (E11, E22, E12) with Compute_ukl
    # and print the resulting effective stiffness components.
    Display.Clear()

    # ----------------------------------------------
    # Configuration
    # ----------------------------------------------

    # use Periodic boundary conditions ?
    # True : displacements are imposed on the nodes tagged P0..P3 only and the
    #        paired nodes are tied with Lagrange conditions in Compute_ukl.
    # False: displacements are imposed on every node tagged L0..L3.
    usePBC = True

    # ----------------------------------------------
    # Mesh
    # ----------------------------------------------
    meshSize = 1 / 15

    # unit square centered at the origin (side length 1, hence area 1)
    p0 = (-1 / 2, -1 / 2)
    p1 = (1 / 2, -1 / 2)
    p2 = (1 / 2, 1 / 2)
    p3 = (-1 / 2, 1 / 2)
    pts = [p0, p1, p2, p3]
    contour = Points(pts, meshSize)

    # inclusion
    # f is the target disc area; since the cell area is 1 it is also the
    # inclusion area fraction. r is chosen such that pi * r**2 == f.
    f = 0.4
    r = 1 * np.sqrt(f / np.pi)
    # NOTE(review): the second argument (2 * r) is presumably the diameter - confirm.
    inclusion = Circle((0, 0), 2 * r, meshSize, isFilled=True)

    mesh = contour.Mesh_2D([inclusion], ElemType.TRI6)

    Display.Plot_Mesh(mesh, title="RVE")

    if usePBC:
        # NOTE(review): P0..P3 are presumably the four points of the contour
        # (the corners of the cell) - confirm against the mesher's tag naming.
        nodes_kubc = mesh.Nodes_Tags(["P0", "P1", "P2", "P3"])
        paired_nodes = mesh.Get_Paired_Nodes(nodes_kubc, True)
    else:
        # NOTE(review): L0..L3 are presumably the four edges of the contour - confirm.
        nodes_kubc = mesh.Nodes_Tags(["L0", "L1", "L2", "L3"])
        paired_nodes = None

    # ----------------------------------------------
    # Material and Simulation
    # ----------------------------------------------
    # NOTE(review): S1 / S0 are presumably the surface tags of the inclusion
    # and of the matrix respectively - confirm.
    elements_inclusion = mesh.Elements_Tags(["S1"])
    elements_matrix = mesh.Elements_Tags(["S0"])

    # element-wise material properties (one value per element)
    E = np.zeros_like(mesh.groupElem.elements, dtype=float)
    v = np.zeros_like(mesh.groupElem.elements, dtype=float)

    E[elements_matrix] = 1  # MPa
    v[elements_matrix] = 0.45

    # skipped when no element carries the inclusion tag
    if elements_inclusion.size > 0:
        E[elements_inclusion] = 50
        v[elements_inclusion] = 0.3

    # planeStress=False: presumably a plane strain assumption - confirm.
    material = Models.Elastic.Isotropic(2, E, v, planeStress=False)

    simu = Simulations.Elastic(mesh, material)

    Display.Plot_Result(simu, E, nodeValues=False, title="E [MPa]")
    Display.Plot_Result(simu, v, nodeValues=False, title="v")

    # ----------------------------------------------
    # Homogenization
    # ----------------------------------------------
    # Three elementary macroscopic strains. The shear strain is scaled by
    # 1/sqrt(2) (presumably Kelvin-Mandel notation), which is why c1212 is
    # printed below as C_hom[2, 2] / 2.
    r2 = np.sqrt(2)
    E11 = np.array([[1, 0], [0, 0]])
    E22 = np.array([[0, 0], [0, 1]])
    E12 = np.array([[0, 1 / r2], [1 / r2, 0]])

    # Each call re-initializes the boundary conditions of the same simu object.
    # The trailing True is pltSol: only the shear solution is plotted.
    u11 = Compute_ukl(simu, nodes_kubc, E11, paired_nodes)
    u22 = Compute_ukl(simu, nodes_kubc, E22, paired_nodes)
    u12 = Compute_ukl(simu, nodes_kubc, E12, paired_nodes, True)

    # element-wise view of each solution
    u11_e = mesh.Locates_sol_e(u11, asFeArray=True)
    u22_e = mesh.Locates_sol_e(u22, asFeArray=True)
    u12_e = mesh.Locates_sol_e(u12, asFeArray=True)

    # ----------------------------------------------
    # Effective elasticity tensor (C_hom)
    # ----------------------------------------------
    # Stack the three solutions along a new trailing axis, so that each column
    # of the product below corresponds to one macroscopic strain.
    U_e = FeArray.zeros(*u11_e.shape, 3)

    U_e[..., 0] = u11_e
    U_e[..., 1] = u22_e
    U_e[..., 2] = u12_e

    # NOTE(review): "mass" presumably selects the quadrature rule used for the
    # integration weights and the B matrices - confirm.
    matrixType = "mass"
    wJ_e_pg = mesh.Get_weightedJacobian_e_pg(matrixType)
    B_e_pg = mesh.Get_B_e_pg(matrixType)

    # material.C reshaped to the leading (element, Gauss point) dimensions of B_e_pg
    C_Mat = Models.Reshape_variable(material.C, *B_e_pg.shape[:2])

    # Integrate C . B . U over the mesh (sum over elements and Gauss points) and
    # divide by the area.
    # Be careful here: the full cell area must be used even if there are holes.
    # When dividing by the mesh area instead, multiply C_hom by (1 - f), i.e.
    # the fraction of the cell that is still meshed once the disc of area f is removed.
    C_hom = (wJ_e_pg * C_Mat @ B_e_pg @ U_e).sum((0, 1)) / mesh.area

    # Only applies when the inclusion is a hole (not filled): rescale the
    # average taken over the mesh area to the full unit cell.
    if not inclusion.isFilled and mesh.area != 1:
        C_hom *= 1 - f

    # Display.Plot_BoundaryConditions(simu)

    print(f"f = {f}")
    print(f"c1111 = {C_hom[0, 0]}")
    print(f"c1122 = {C_hom[0, 1]}")
    print(f"c1212 = {C_hom[2, 2] / 2}")

    plt.show()
