#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0-only
#
# ifstool: A tool to monitor and report interference statistics (CONFIG_CGROUP_IFS).
#
# Copyright(c) 2026. Huawei Technologies Co., Ltd
#
# Authors:
#  Tengda Wu <wutengda2@huawei.com>

import os
import time
import re
import csv
import argparse
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.offline import get_plotlyjs


class InterferenceTool:
    """
    A utility class to monitor Linux cgroup interference statistics (CONFIG_CGROUP_IFS)
    and generate visual comparison reports between baseline and current data.
    """

    def __init__(self):
        # Default path for cgroup v2 interference stats
        self.base_path = "/sys/fs/cgroup"

    def parse_stat(self, content: str):
        """
        Parses the raw content of the interference.stat file.

        The file starts with per-category total times ('<name> <ns>'),
        followed by one '<name> distribution' histogram section per category.

        Args:
            content (str): Raw string content from the stat file.
        Returns:
            tuple: (total_times dict, distributions dict)
        """

        total_times, distributions = {}, {}

        # Split content into sections: Top-level totals and various distributions
        # Uses positive lookahead to split before a word followed by ' distribution'
        sections = re.split(r"\n(?=[a-z]+ distribution)", content)

        # Parse global total times (first section)
        for line in sections[0].strip().split("\n"):
            match = re.match(r"^([a-z]+)\s+(\d+)$", line.strip())
            if match:
                total_times[match.group(1)] = int(match.group(2))

        # Parse histogram distributions (subsequent sections)
        for section in sections[1:]:
            lines = section.strip().split("\n")
            # Extract header name (e.g., 'spinlock distribution' -> 'spinlock')
            header = lines[0].replace(" distribution", "").strip()
            # Parse bucket key-value pairs (e.g., '[64 ns, 128 ns) : 143791').
            # str.partition splits on the first ':' only and scans each line
            # once (the original called split(":") twice per line).
            dist_data = {}
            for entry in lines[1:]:
                if ":" in entry:
                    label, _, count = entry.partition(":")
                    dist_data[label.strip()] = int(count.strip())
            distributions[header] = dist_data

        return total_times, distributions

    def monitor(self, cgroup_id: str, interval: float, duration: int, output_csv: str):
        """
        Periodically samples interference stats and saves results to a CSV file.

        Args:
            cgroup_id (str): The specific cgroup folder name. For example,
                /sys/kernel/cgroup/A/B/C, when want to monitor level B, the
                cgroup_id is A/B.
            interval (float): Seconds between samples.
            duration (int): Total monitoring time in seconds.
            output_csv (str): Filename to save sampled data.
        """

        # Probe both the unified (v2) hierarchy and the v1 'cpu' controller path
        candidates = [
            os.path.join(self.base_path, cgroup_id),
            os.path.join(self.base_path, "cpu", cgroup_id),  # cgroup v1
        ]
        path = next((p for p in candidates if os.path.exists(p)), None)

        if not path:
            print(
                f"\n[!] No access to cgroup: {os.path.join(self.base_path, cgroup_id)}"
            )
            print("    Hint: You can find the correct cid by executing:")
            print("          cat /proc/<PID>/cgroup")
            return

        path = os.path.join(path, "interference.stat")
        if not os.path.exists(path):
            print(f"\n[!] Interface file '{path}' not found.")
            print("    This usually happens due to one of the following:")
            print("    1. Kernel is not compiled with CONFIG_CGROUP_IFS=y")
            print(
                "    2. CONFIG_CGROUP_IFS_DEFAULT_ENABLED is not set to y",
                "and the boot parameter 'cgroup_ifs=1' is not added",
            )
            print("    3. The current cgroup is not managed by the IFS controller")
            return

        print(
            f"[*] Starting monitor: {cgroup_id}, interval: {interval}s, duration: {duration}s"
        )

        data_list = []
        start_time = time.time()

        try:
            while (time.time() - start_time) < duration:
                # Read the clock ONCE per sample: the original called
                # time.time() a second time to build the fractional suffix,
                # so the centiseconds could belong to a neighbouring second
                # (skewed or duplicate timestamps around second boundaries).
                now = time.time()
                ts = time.strftime("%H:%M:%S", time.localtime(now))
                # Using fractional seconds in timestamp for sub-second intervals
                if interval < 1.0:
                    ts += f".{int((now % 1) * 100):02d}"

                with open(path, "r") as f:
                    total_times, dists = self.parse_stat(f.read())

                # Append total time metrics
                for cat, val in total_times.items():
                    data_list.append([ts, cgroup_id, cat, "total_time_ns", val])

                # Append bucket distribution metrics
                for cat, dist in dists.items():
                    for b, c in dist.items():
                        data_list.append([ts, cgroup_id, cat, f"bucket_{b}", c])

                time.sleep(interval)
        except KeyboardInterrupt:
            # Partial data is still flushed to CSV below
            print("\n[!] Monitoring interrupted by user.")

        # Write buffered data to CSV
        with open(output_csv, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(
                ["timestamp", "cgroup_id", "category", "metric_type", "value"]
            )
            writer.writerows(data_list)
        print(f"[+] Data exported successfully: {output_csv}")

    def bucket_key(self, bucket_str: str):
        """
        Parsing logic for histogram bucket labels to enable correct numerical sorting.
        Example input: "bucket_[67.10 ms, 134.21 ms)"

        Returns:
            float: The bucket's first (lower-bound) value converted to
                nanoseconds (ns); 0 if the label cannot be parsed.
        """
        try:
            # Regex to capture the first number and its unit (ns|us|ms|s)
            match = re.search(r"(\d+\.?\d*)\s*(ns|us|ms|s)", bucket_str)
            if not match:
                return 0

            value = float(match.group(1))
            unit = match.group(2).lower()

            # Conversion factors to nanoseconds
            factors = {"ns": 1, "us": 1000, "ms": 1000000, "s": 1000000000}

            return value * factors.get(unit, 1)
        except Exception:
            # Sort keys must never raise; unparsable labels sort first
            return 0

    def _delta_matrix(self, df, cat):
        """
        Build a per-interval delta matrix (buckets x timestamps) for one category.

        Args:
            df (pd.DataFrame): Monitoring data loaded from CSV.
            cat (str): Category name (e.g. 'spinlock').
        Returns:
            tuple: (delta DataFrame, stats dict) — empty frame/dict when the
                category has no bucket samples.
        """
        sub = df[
            (df["category"] == cat)
            & (df["metric_type"].str.startswith("bucket_"))
        ]
        if sub.empty:
            return pd.DataFrame(), {}

        # Transform to matrix (Buckets vs Time). pivot_table with 'last'
        # tolerates duplicate timestamps (possible at very small sampling
        # intervals, where the centisecond resolution of the timestamp can
        # repeat); DataFrame.pivot would raise ValueError in that case.
        p = sub.pivot_table(
            index="metric_type", columns="timestamp", values="value", aggfunc="last"
        ).sort_index(axis=1)
        # Sort Y-axis buckets based on numerical time value
        p = p.reindex(sorted(p.index, key=self.bucket_key))
        # Counters are cumulative; convert to per-interval increments (delta)
        delta = p.diff(axis=1).fillna(0)
        return delta, {
            "Total": int(delta.values.sum()),
            "Peak": int(delta.values.max()),
        }

    def report(self, base_csv, curr_csv, output_html):
        """
        Loads two CSV files and generates an interactive HTML dashboard.

        Args:
            base_csv (str): Path to baseline data.
            curr_csv (str): Path to current data (None for single-file mode).
            output_html (str): Path to generate the HTML report.
        """

        df_b = pd.read_csv(base_csv)
        # Single mode: analyse one dataset, no side-by-side comparison
        single_mode = curr_csv is None or base_csv == curr_csv
        df_c = None if single_mode else pd.read_csv(curr_csv)

        # Get unique union of categories present in both datasets
        all_dfs = [df_b] if single_mode else [df_b, df_c]
        categories = sorted(set().union(*(df.category.unique() for df in all_dfs)))

        # Color mapping to ensure consistent colors for categories across plots
        color_sequence = [
            "#636EFA",
            "#EF553B",
            "#00CC96",
            "#AB63FA",
            "#FFA15A",
            "#19D3F3",
            "#FF6692",
            "#B6E880",
            "#FF97FF",
            "#FECB52",
        ]
        color_map = {
            cat: color_sequence[i % len(color_sequence)]
            for i, cat in enumerate(categories)
        }

        # Build HTML content with CSS for layout; plotly.js is inlined so the
        # report is a self-contained file
        html_parts = [
            "<html><head><title>IFS Analysis Dashboard</title>",
            f"<script type='text/javascript'>{get_plotlyjs()}</script>",
        ]
        col_width = "100%" if single_mode else "49.5%"
        html_parts.append(
            f"""<style>
            body{{font-family:'Segoe UI',sans-serif; background:#f0f2f5; padding:20px; color:#333;}}
            .card{{background:white; padding:15px; margin-bottom:20px; border-radius:10px; box-shadow:0 2px 5px rgba(0,0,0,0.05);}}
            .row-container{{display:flex; gap:15px; justify-content:space-between;}}
            .col-item{{width:{col_width}; min-width:0;}}
            .stat-box{{display:flex; gap:10px; margin:5px 0; font-size:0.85em; color:#555;}}
            .stat-item{{background:#f1f3f5; padding:4px 8px; border-radius:4px; border-left:3px solid #007bff;}}
            .section-header{{margin:30px 0 15px 0; padding-bottom:10px; border-bottom:2px solid #ddd;}}
        </style></head><body>"""
        )

        title = "Performance Analysis" if single_mode else "Performance Comparison"
        html_parts.append(f"<h2>{title} (CONFIG_CGROUP_IFS)</h2>")
        info = (
            f"File: {base_csv}"
            if single_mode
            else f"Baseline: {base_csv} | Current: {curr_csv}"
        )
        html_parts.append(f"<p style='color:#666;'>{info}</p>")

        # --- Part 1: Total Time Trends (Line Charts) ---
        html_parts.append("<div class='card'>")
        html_parts.append("<h3 style='margin-top:0;'>Total Time Delta (ns) Trend</h3>")

        # Subplots ensure Y-axis can be matched for direct visual comparison
        fig_line = make_subplots(
            rows=1,
            cols=1 if single_mode else 2,
            horizontal_spacing=0.05,
            subplot_titles=("Latency",) if single_mode else ("Baseline", "Current"),
        )

        plot_configs = (
            [("Data", df_b)] if single_mode else [("Baseline", df_b), ("Current", df_c)]
        )
        for i, (name, df) in enumerate(plot_configs, 1):
            for cat in categories:
                sub = df[
                    (df["category"] == cat) & (df["metric_type"] == "total_time_ns")
                ].sort_values("timestamp")
                if not sub.empty:
                    # Cumulative counter -> per-interval increments
                    y_val = sub["value"].diff().fillna(0)
                    fig_line.add_trace(
                        go.Scatter(
                            x=sub["timestamp"],
                            y=y_val,
                            name=f"{cat} ({name})",
                            legendgroup=cat,
                            mode="lines+markers",
                            line=dict(color=color_map[cat], width=2),
                            marker=dict(color=color_map[cat], size=6),
                        ),
                        row=1,
                        col=i,
                    )

        # Sync Y-axes scale for baseline and current plots
        if not single_mode:
            fig_line.update_yaxes(matches="y", row=1, col=2)
        fig_line.update_layout(
            height=500, template="plotly_white", margin=dict(t=50, b=20)
        )
        html_parts.append(fig_line.to_html(full_html=False, include_plotlyjs=False))
        html_parts.append("</div>")

        html_parts.append(
            "<h3 class='section-header'>Detailed Latency Distribution</h3>"
        )

        # --- Part 2: Distribution Heatmaps ---
        for cat in categories:
            m_b, stats_b = self._delta_matrix(df_b, cat)
            m_c, stats_c = (
                (None, None) if single_mode else self._delta_matrix(df_c, cat)
            )

            # Determine global max for heatmap color scaling consistency
            z_vals = [m_b.values.max() if not m_b.empty else 0]
            if not single_mode:
                z_vals.append(m_c.values.max() if not m_c.empty else 0)
            global_max_z = max(z_vals + [1])

            html_parts.append(f"<div class='card'><h4>Category: {cat.upper()}</h4>")
            html_parts.append("<div class='row-container'>")

            configs = (
                [("Data", m_b, stats_b)]
                if single_mode
                else [("Baseline", m_b, stats_b), ("Current", m_c, stats_c)]
            )
            for name, m, stats in configs:
                html_parts.append("<div class='col-item'>")
                if not m.empty:
                    # Display summary statistics
                    html_parts.append("<div class='stat-box'>")
                    for k, v in stats.items():
                        html_parts.append(
                            f"<div class='stat-item'><b>{k}:</b> {v}</div>"
                        )
                    html_parts.append("</div>")

                    # Generate Heatmap
                    fig = go.Figure(
                        data=go.Heatmap(
                            z=m.values,
                            x=m.columns,
                            y=m.index,
                            colorscale="Viridis",
                            zmin=0,
                            zmax=global_max_z,
                            colorbar=dict(title="Count", thickness=10, len=0.7),
                        )
                    )
                    fig.update_layout(
                        height=400,
                        margin=dict(l=120, r=0, t=10, b=30),
                        template="plotly_white",
                    )
                    html_parts.append(
                        fig.to_html(full_html=False, include_plotlyjs=False)
                    )
                else:
                    html_parts.append(
                        "<div style='height:100px; display:flex; align-items:center; justify-content:center; background:#fafafa; color:#ccc;'>No Data</div>"
                    )
                html_parts.append("</div>")

            html_parts.append("</div></div>")  # end row-container & card

        # Save the final HTML report
        with open(output_html, "w") as f:
            f.writelines(html_parts + ["</body></html>"])
        print(f"[+] Comparison report generated: {output_html}")


def _build_cli():
    """Construct the CLI parser with 'monitor' and 'report' subcommands."""
    parser = argparse.ArgumentParser()
    modes = parser.add_subparsers(dest="mode")

    # Monitor Mode Arguments
    mon = modes.add_parser("monitor")
    mon.add_argument(
        "-G",
        "--cgroup",
        required=True,
        help="Cgroup name (e.g. 'docker/cid'). "
        "Tip: Find this via 'cat /proc/<PID>/cgroup'",
    )
    mon.add_argument(
        "-i", "--interval", type=float, default=1.0, help="Sampling interval (sec)"
    )
    mon.add_argument(
        "-d", "--duration", type=int, default=30, help="Sampling duration (sec)"
    )
    mon.add_argument("-o", "--output", default="capture.csv")

    # Report Mode Arguments
    rep = modes.add_parser("report")
    rep.add_argument("-b", "--base", required=True, help="Baseline CSV file")
    rep.add_argument(
        "-c", "--curr", default=None, help="Current CSV file (optional, for comparison)"
    )
    rep.add_argument("-o", "--output", default="report.html")

    return parser


if __name__ == "__main__":
    cli = _build_cli()
    args = cli.parse_args()
    tool = InterferenceTool()

    # Dispatch to the selected mode; fall back to usage text otherwise
    if args.mode == "monitor":
        tool.monitor(args.cgroup, args.interval, args.duration, args.output)
    elif args.mode == "report":
        tool.report(args.base, args.curr, args.output)
    else:
        cli.print_help()
