# configure_gns3_from_scratch.py
import json
import os
import shlex
import subprocess
import sys
import threading
import time

from scapy.all import *
from scapy.layers.inet import IP
from scapy.layers.smb import SMB_Header
import networkx as nx
import matplotlib.pyplot as plt

from gns3manager import *
import gns3manager
from configuration_manager import ConfigurationManager
from ubuntu_base import vc_server

# Root directory of the testbed checkout; run_vc_server() launches
# <DIR>/ubuntu_base/vc_server.py. Set the GNS3_TESTBED_DIR environment
# variable to override the developer-specific default on other machines
# (an empty value falls back to the default).
DIR = os.environ.get("GNS3_TESTBED_DIR") or "/home/anh/Documents/kernel_preparation/gns3_unikernel_testbed"


class GraphGenerator:
    """Layout and IP-bookkeeping helpers used while building the GNS3 topology."""

    # Class-level registry shared by all callers: maps the MAC address of a
    # node's first port to {adapter: ip_address}. Populated by
    # map_connection_to_ip_address() and serialised to JSON by run_vc_server().
    map_connection_ubuntu_base = {}

    @staticmethod
    def generate_topology_layout(num_window_hosts, num_ubuntu_base, numb_ubuntu_intermediate, connections, scale=1000,
                                 seed=None):
        """Compute integer (x, y) GNS3 coordinates for the topology with a spring layout.

        Args:
            num_window_hosts: number of Windows PCs, named ``pc_0`` .. ``pc_<n-1>``.
            num_ubuntu_base: number of Ubuntu base nodes, named ``ub_<i>``.
            numb_ubuntu_intermediate: number of Ubuntu intermediate nodes, named ``ui_<i>``.
            connections: iterable of ``(node_a, node_b)`` name pairs added as edges.
            scale: scale factor passed to ``networkx.spring_layout``.
            seed: optional random seed forwarded to ``networkx.spring_layout``.
                Pass an int for a reproducible layout; the default ``None``
                keeps the previous non-deterministic behaviour.

        Returns:
            tuple: ``(int_positions, pc_loc_lst, ub_loc_lst, ui_loc_lst)``.
            ``int_positions`` maps every node name to ``{"x": int, "y": int}``.
            Each ``*_loc_lst`` holds those dicts indexed by the node's number.

        Raises:
            IndexError: if ``connections`` mentions a ``pc``/``ub``/``ui`` node
                whose index is outside the corresponding count.
        """
        G = nx.Graph()

        # Node names follow the "<type>_<index>" convention of the
        # configuration strings: pc = Windows host, ub = Ubuntu base,
        # ui = Ubuntu intermediate.
        window_hosts = [f"pc_{i}" for i in range(num_window_hosts)]
        ubuntu_bases = [f"ub_{i}" for i in range(num_ubuntu_base)]
        ubuntu_intermediates = [f"ui_{i}" for i in range(numb_ubuntu_intermediate)]

        G.add_nodes_from(window_hosts)
        G.add_nodes_from(ubuntu_bases)
        G.add_nodes_from(ubuntu_intermediates)
        G.add_edges_from(connections)

        # Larger weights mean a stronger attractive force in spring_layout, so
        # PCs cluster tightly around intermediates and intermediates around
        # base nodes; every other edge gets the default pull.
        for u, v in G.edges:
            if ("pc_" in u and "ui_" in v) or ("pc_" in v and "ui_" in u):
                G[u][v]["weight"] = 20.0
            elif ("ub_" in v and "ui_" in u) or ("ui_" in v and "ub_" in u):
                G[u][v]["weight"] = 10.0
            else:
                G[u][v]["weight"] = 1.0

        positions = nx.spring_layout(G, scale=scale, weight="weight", seed=seed)

        # GNS3 expects integer coordinates.
        int_positions = {
            node: {"x": int(pos[0]), "y": int(pos[1])} for node, pos in positions.items()
        }

        # Independent placeholder dicts (``[{}] * n`` would alias one dict n
        # times), then filled by node index according to the name prefix.
        pc_loc_lst = [{} for _ in range(num_window_hosts)]
        ub_loc_lst = [{} for _ in range(num_ubuntu_base)]
        ui_loc_lst = [{} for _ in range(numb_ubuntu_intermediate)]
        buckets = {"pc": pc_loc_lst, "ub": ub_loc_lst, "ui": ui_loc_lst}
        for node, position in int_positions.items():
            parts = node.split("_")
            bucket = buckets.get(parts[0])
            if bucket is None:
                continue
            bucket[int(parts[1])] = position
        return int_positions, pc_loc_lst, ub_loc_lst, ui_loc_lst

    @staticmethod
    def plot_topology(layout, connections):
        """Draw the generated layout with matplotlib and show it.

        Args:
            layout (dict): node positions ``{node: {"x": x, "y": y}}``.
            connections (list): ``(node_a, node_b)`` pairs drawn as black lines;
                both names must be keys of ``layout``.

        The y axis is inverted so the picture matches GNS3's screen
        coordinates (y grows downwards).
        """
        plt.figure(figsize=(10, 8))

        # Nodes: one labelled scatter point each.
        for node, coords in layout.items():
            x, y = coords["x"], coords["y"]
            plt.scatter(x, y, label=node, s=100, alpha=0.8)
            plt.text(x + 20, y, node, fontsize=10, ha="left", va="center")

        # Links.
        for node1, node2 in connections:
            x1, y1 = layout[node1]["x"], layout[node1]["y"]
            x2, y2 = layout[node2]["x"], layout[node2]["y"]
            plt.plot([x1, x2], [y1, y2], "k-", alpha=0.6)

        plt.title("Network Topology Layout")
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.grid(True, linestyle="--", alpha=0.5)
        plt.axhline(0, color="black", linewidth=0.5, alpha=0.7)
        plt.axvline(0, color="black", linewidth=0.5, alpha=0.7)
        plt.gca().invert_yaxis()
        plt.legend()
        plt.tight_layout()

        plt.show()

    @staticmethod
    def generate_location(num_window_hosts, num_ubuntu_base, numb_ubuntu_intermediate, conf_strs, plot=True,
                          seed=None):
        """Derive node coordinates from the connection configuration strings.

        Each configuration string is a ``;``-separated list of ``a,b``
        connection commands, parsed with ConfigurationManager.extract_component_id.
        Only commands whose two ends are both PCs (``pc``), Ubuntu base
        (``ub``) or Ubuntu intermediate (``ui``) nodes contribute edges to the
        layout. A ranged component expands to one edge per index pair.

        Args:
            num_window_hosts, num_ubuntu_base, numb_ubuntu_intermediate: node
                counts, forwarded to generate_topology_layout().
            conf_strs: list of configuration strings.
            plot: when true (the default), display the layout with
                plot_topology(). Under matplotlib's default non-interactive
                mode this waits until the window is closed.
            seed: optional layout seed, forwarded to generate_topology_layout().

        Returns:
            tuple: ``(pc_loc_lst, ub_loc_lst, ui_loc_lst)``, as produced by
            generate_topology_layout().
        """
        layout_types = ("pc", "ub", "ui")
        connections = []
        for conf_str in conf_strs:
            for conn_cmd in conf_str.split(";"):
                nodes = conn_cmd.split(",")
                node1_index, node1_type = ConfigurationManager.extract_component_id(nodes[0])
                node2_index, node2_type = ConfigurationManager.extract_component_id(nodes[1])
                if node1_type not in layout_types or node2_type not in layout_types:
                    continue
                for i_node_1 in node1_index:
                    node_1 = f"{node1_type}_{i_node_1}"
                    for i_node_2 in node2_index:
                        connections.append((node_1, f"{node2_type}_{i_node_2}"))

        layout, pc_loc_lst, ub_loc_lst, ui_loc_lst = GraphGenerator.generate_topology_layout(
            num_window_hosts, num_ubuntu_base, numb_ubuntu_intermediate, connections, seed=seed)
        print("Generated Layout:", layout)

        if plot:
            GraphGenerator.plot_topology(layout, connections)

        return pc_loc_lst, ub_loc_lst, ui_loc_lst

    @staticmethod
    def map_connection_to_ip_address(host1_id, host2_id, host1_port, host1_adapter, host2_port, host2_adapter, ip_addresses):
        """Record the IP address assigned to each end of a link.

        Each endpoint is identified by the MAC address of its first port, as
        reported by GNS3Manager.get_node(). The IP is stored in
        ``map_connection_ubuntu_base[mac][adapter]``. An existing entry for
        that MAC is extended, not replaced.

        Args:
            host1_id, host2_id: GNS3 node ids of the two endpoints.
            host1_port, host2_port: unused; kept for call-site compatibility.
            host1_adapter, host2_adapter: adapter of each endpoint used by the
                link, used as the inner dict key.
            ip_addresses: ``"<ip of host1>,<ip of host2>"``; any further
                comma-separated fields are ignored.
        """
        # Fetch both nodes before touching the registry, so a failing lookup
        # leaves it unchanged.
        mac_1 = GNS3Manager.get_node(host1_id)["ports"][0]["mac_address"]
        mac_2 = GNS3Manager.get_node(host2_id)["ports"][0]["mac_address"]

        ip_parts = ip_addresses.split(",")
        node_1_ip, node_2_ip = ip_parts[0], ip_parts[1]

        registry = GraphGenerator.map_connection_ubuntu_base
        registry.setdefault(mac_1, {})[host1_adapter] = node_1_ip
        registry.setdefault(mac_2, {})[host2_adapter] = node_2_ip


# def duplicate_project(project_name):
#     global PROJECT_ID, username, password
#     url = "http://localhost:3080/v2/projects/" + PROJECT_ID + "/duplicate"
#     data = {"name": project_name.replace(".gns3", "") + "-2"}
#
#     response = requests.post(url, json=data, auth=(username, password))
#     if response.status_code == 201:
#         json_data = response.json()
#         for project in json_data:
#             if project["status"] == "opened":
#                 print(project["filename"])
#                 return project["project_id"], project["filename"]
#     else:
#         print(f"Error duplicate_project: {response.status_code}")


def process_packet(pkt):
    """scapy ``prn`` callback that records SMB infection events from a capture.

    Reads or updates these module globals, which the ``read`` command sets up:
      - map_infecting: victim IP -> list of source IPs seen sending an SMB
        packet with Command 0xa0 and Flags 0x18, while no infector has been
        found yet in the current file. It is dumped as JSON to infecting.txt
        after every such packet.
      - array_point / map_infector: ``"src,dst"`` records and source IPs of
        packets carrying the malware-specific marker. array_point is written
        to infector.txt.
      - continue_sniffing / is_infected: set to False / True once a marker
        packet is seen. ``sniff()``'s stop_filter then ends the current file.
      - malware_instance: 0 (WannaCry) matches on SMB MID == 81; 1 matches on
        SMB Reserved == 0x0011.

    NOTE(review): 0xa0 looks like SMB_COM_NT_TRANSACT and the MID/Reserved
    values look exploit-specific. Confirm them against the malware samples.
    """
    global array_point, continue_sniffing, malware_instance, is_infected, map_infecting, map_infector
    # The "port 445" BPF filter also lets SMB over IPv6 through. Such packets
    # have no IPv4 layer, and pkt[IP] would raise and abort the whole sniff().
    if SMB_Header not in pkt or IP not in pkt:
        return
    smb = pkt[SMB_Header]
    src_ip = pkt[IP].src
    dst_ip = pkt[IP].dst

    if not is_infected and smb.Command == 0xa0 and smb.Flags == 0x18:
        infecting_ips = map_infecting.setdefault(dst_ip, [])
        if src_ip not in infecting_ips:
            infecting_ips.append(src_ip)
        with open("infecting.txt", "w", encoding="utf-8") as f:
            json.dump(map_infecting, f)

    if (malware_instance == 0 and smb.MID == 81) or (
            malware_instance == 1 and smb.Reserved == 0x0011):
        array_point.append(src_ip + "," + dst_ip)
        if src_ip not in map_infector:
            map_infector[src_ip] = 1
        else:
            # The same source was already recorded as an infector (possibly
            # in an earlier file); pause so the operator notices.
            input("Error" + src_ip)
        print("saved", src_ip, " - ", dst_ip, "- size", len(array_point))

        with open("infector.txt", "w", encoding="utf-8") as f:
            f.write(str(array_point))
        continue_sniffing = False
        is_infected = True


def run_vc_server():
    """Start ubuntu_base/vc_server.py in a detached, held-open xterm window.

    GraphGenerator.map_connection_ubuntu_base is serialised to JSON and passed
    to the script as its single command-line argument. The function returns
    the xterm ``Popen`` handle at once, without waiting for the process. The
    "1" command runs it in a daemon thread.
    """
    map_json = json.dumps(GraphGenerator.map_connection_ubuntu_base)
    script_path = f"{DIR}/ubuntu_base/vc_server.py"
    # xterm runs a single-string -e command through a shell (the original
    # hand-written '...' quoting relied on this). Quote the path and the JSON
    # payload properly: a ' in the data or a space in DIR would otherwise
    # break the command line or inject shell syntax.
    command_str = f"python3 {shlex.quote(script_path)} {shlex.quote(map_json)}"
    return subprocess.Popen(
        ["xterm", "-hold", "-e", command_str],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        stdin=subprocess.DEVNULL,
        # Run in a new session so the xterm is detached from this script's
        # terminal and process group.
        start_new_session=True,
    )


# Command-line entry point; sys.argv[1] selects the action:
#   1 <config_file> <ip_mapping_file>  build the topology in the selected GNS3
#                                      project, then interactively start nodes
#   node <node_id>                     print field 1 (the node name, judging by
#                                      the "R1"/"NAT1" checks below) of every
#                                      node whose id contains <node_id>
#   capture                            start packet capture on every link
#   nocapture                          call GNS3Manager.nocapture()
#   read <malware_name>                scan the project's .pcap captures with
#                                      process_packet()
# Before any action, the operator must confirm with Enter the project
# returned by GNS3Manager.get_projectId(); any other input exits.
if __name__ == "__main__":
    gns3manager.PROJECT_ID, project_name = GNS3Manager.get_projectId()
    user_input = input("Press enter if Correct project to rennder?")
    if user_input != '':
        sys.exit()

    if len(sys.argv) > 1:
        command = sys.argv[1]
        if command == "1":
            # Configuration file layout, one value per line:
            #   0: malware name ("wannacry" -> instance 0, anything else -> 1)
            #   1: switches   2: Windows PCs   3: DHCP unikernels
            #   4: Ubuntu base nodes   5: Ubuntu intermediate nodes
            #   6: is_default flag (int; non-zero replaces the connection
            #      strings with the built-in default further below)
            #   7+: connection strings, "a,b;c,d;..."
            with open(sys.argv[2], 'r') as file:
                lines = file.readlines()
            gns3manager.malware_instance = 0 if lines[0].strip() == "wannacry" else 1
            NUM_SWITCH = int(lines[1].strip())  # Number of switches
            NUM_HOST = int(lines[2].strip())  # Number of PCs
            NUM_DHCP_UNIKERNEL = int(lines[3].strip())  # Number of dhcp-unikernel
            NUM_UBUNTU_BASE = int(lines[4].strip())  # Number of Ubuntu base nodes
            NUM_UBUNTU_INTERMEDIATE = int(lines[5].strip())  # Number of Ubuntu intermediate nodes
            is_default = int(lines[6].strip())  # Non-zero: use the built-in default connections
            index_line = 7
            configuration_strings = []
            for index_line in range(index_line, len(lines)):
                configuration_strings.append(lines[index_line].strip())  # Configuration string
            print("Number of switches:", NUM_SWITCH)
            print("Number of PCs:", NUM_HOST)
            print("Number of DHCP unikernel:", NUM_DHCP_UNIKERNEL)
            print("Number of ubuntu base:", NUM_UBUNTU_BASE)
            print("Number of ubuntu intermediate:", NUM_UBUNTU_INTERMEDIATE)
            print("Is default configuration:", is_default)
            print("Configuration string:", configuration_strings)

            # Verify the strings
            ConfigurationManager.verify_strings(configuration_strings)

            # Read additional file containing IP addresses
            # Each line is "<type>_<i>,<type>_<j>:<ip_i>,<ip_j>". The key is
            # looked up in both orders when the links are created below.
            # NOTE(review): split(":") raises ValueError on blank lines and on
            # values that contain ':' (e.g. IPv6 addresses).
            mapping_ip_connections = {}
            with open(sys.argv[3], 'r') as file:
                for line in file:
                    key, value = line.strip().split(":")
                    mapping_ip_connections[key] = value

            array_existing_nodes = GNS3Manager.get_nodes()
            array_window_hosts = []
            array_switches = []
            array_dhcp_unikernel = []
            array_ubuntu_base = []
            array_ubuntu_intermediate = []
            # The router and the VC server reuse nodes already present in the
            # project; their ids are filled in from the loop below. A fresh
            # Ubuntu server node is created at (1200, 1000).
            router_0 = Node(1, 1, "", ROUTER_TYPE)
            ubuntu_server = Node(1, 1, GNS3Manager.create_ubuntu_server(1200, 1000), UBUNTU_TYPE)
            ubuntu_vc_server = Node(1, 1, "", UBUNTU_VC_SERVER_TYPE)

            # The existing node named "R1" becomes router_0. The one named
            # "NAT1" stands in for the Ubuntu VC server.
            # NOTE(review): x_node/y_node/z_node are assigned but never used.
            for exist_node in array_existing_nodes:
                if exist_node[1] == "R1":
                    router_0.id = exist_node[0]
                    x_node = exist_node[2]
                    y_node = exist_node[3]
                    z_node = exist_node[4]
                elif exist_node[1] == "NAT1":
                    ubuntu_vc_server.id = exist_node[0]
                    x_node = exist_node[2]
                    y_node = exist_node[3]
                    z_node = exist_node[4]

            # # Generate the layout
            # NOTE(review): the layout is computed from the file's connection
            # strings, even when is_default later replaces them.
            pc_location_lst, ub_location_lst, ui_location_lst = GraphGenerator.generate_location(NUM_HOST, NUM_UBUNTU_BASE, NUM_UBUNTU_INTERMEDIATE, configuration_strings)

            # Create switches
            # Switches are stacked in a column at x=-1200, 100 units apart.
            x_local = -1200
            y_local = -1000
            step = 100
            for i in range(NUM_SWITCH):
                x_temp = x_local
                y_temp = y_local + step * i
                switch_id = GNS3Manager.create_switch(MAX_SIZE_OF_SWITCH_PORT, x_temp, y_temp, 1)
                array_switches.append(Node(MAX_SIZE_OF_SWITCH_PORT, 1, switch_id, SWITCH_TYPE))

            # Create Windows hosts
            # The first PC is created with a last argument of 1 and the others
            # with 0. Presumably this marks the initially infected host; confirm
            # in GNS3Manager.create_window. PC 0 is also excluded from link
            # capture below.
            if NUM_HOST > 0:
                array_window_hosts.append(Node(1, 1, GNS3Manager.create_window(pc_location_lst[0]["x"], pc_location_lst[0]["y"], 1), WINDOW_TYPE))
                for i in range(1, NUM_HOST):
                    x_temp = pc_location_lst[i]["x"]
                    y_temp = pc_location_lst[i]["y"]
                    array_window_hosts.append(Node(1, 1, GNS3Manager.create_window(x_temp, y_temp, 0), WINDOW_TYPE))

            # create ubuntu base
            for i in range(NUM_UBUNTU_BASE):
                x_temp = ub_location_lst[i]["x"]
                y_temp = ub_location_lst[i]["y"]
                array_ubuntu_base.append(Node(1, MAX_SIZE_OF_UBUNTU_BASE_PORT, GNS3Manager.create_ubuntu_base(x_temp, y_temp), UBUNTU_BASE_TYPE))

            # create dhcp unikernel
            # DHCP unikernels are stacked in a column at x=-1400, 100 units apart.
            x_local = -1400
            y_local = -1000
            step = 100
            for i in range(NUM_DHCP_UNIKERNEL):
                x_temp = x_local
                y_temp = y_local + step * i
                array_dhcp_unikernel.append(
                    Node(1, 1, GNS3Manager.create_unikernel(x_temp, y_temp, DHCP_UNIKERNEL_TYPE), DHCP_UNIKERNEL_TYPE))

            # create ubuntu intermediate
            for i in range(NUM_UBUNTU_INTERMEDIATE):
                x_temp = ui_location_lst[i]["x"]
                y_temp = ui_location_lst[i]["y"]
                array_ubuntu_intermediate.append(Node(1, MAX_SIZE_OF_UBUNTU_BASE_PORT, GNS3Manager.create_ubuntu_intermediate(x_temp, y_temp), UBUNTU_INTERMEDIATE_TYPE))

            # In default mode, the hard-coded connection string below replaces
            # the file's strings. Otherwise an empty configuration exits.
            if is_default:
                configuration_strings = ["sw_0,pc_0-49;sw_1,pc_50-99;sw,sw;sw_0,r_0;sw_0,u_0;ub_0,ub_1"]
            else:
                if not configuration_strings:
                    sys.exit()

            # Create the links. "sw,sw" chains all switches in index order.
            # Any other "a,b" command connects every index of component a to
            # every index of component b. When the pair appears in the IP
            # mapping file, the assigned IPs are recorded for run_vc_server().
            for configuration_string in configuration_strings:
                connection_cmd_split = configuration_string.split(";")
                for connection_cmd in connection_cmd_split:
                    components = connection_cmd.split(",")
                    if components[0] == "sw" and components[1] == "sw":
                        for i in range(NUM_SWITCH - 1):
                            switch_1 = array_switches[i]
                            switch_2 = array_switches[i + 1]
                            switch_1.connect(switch_2)
                    else:
                        host1_index, host1_type = ConfigurationManager.extract_component_id(components[0])
                        host2_index, host2_type = ConfigurationManager.extract_component_id(components[1])
                        for index1 in host1_index:
                            host1 = ConfigurationManager.get_host(index1, host1_type,
                                                                  ubuntu_server, router_0, array_window_hosts, array_switches, array_dhcp_unikernel, array_ubuntu_base,
                                                                  array_ubuntu_intermediate, ubuntu_vc_server)
                            for index2 in host2_index:
                                host2 = ConfigurationManager.get_host(index2, host2_type,
                                                                      ubuntu_server, router_0, array_window_hosts, array_switches, array_dhcp_unikernel, array_ubuntu_base,
                                                                      array_ubuntu_intermediate, ubuntu_vc_server)
                                print(host1_type, index1, "connect", host2_type, index2)
                                host1_port, host1_adapter, host2_port, host2_adapter = host1.connect(host2)
                                key_map = f"{host1_type}_{index1},{host2_type}_{index2}"
                                if key_map in mapping_ip_connections:
                                    GraphGenerator.map_connection_to_ip_address(host1.id, host2.id, host1_port, host1_adapter, host2_port, host2_adapter,
                                                                                mapping_ip_connections[key_map])
                                # NOTE(review): for the reversed key, the stored
                                # value presumably lists host2's IP first, yet
                                # host1's id/adapter are still passed first, so
                                # the two IPs look swapped. Confirm the file format.
                                key_map = f"{host2_type}_{index2},{host1_type}_{index1}"
                                if key_map in mapping_ip_connections:
                                    GraphGenerator.map_connection_to_ip_address(host1.id, host2.id, host1_port, host1_adapter, host2_port, host2_adapter,
                                                                                mapping_ip_connections[key_map])

            # Press y to start Ubuntu VC server and dhcp
            # if input("Start DHCP (y/N)?") != "y":
            #     sys.exit()
            # GNS3Manager.start_node(ubuntu_vc_server.id)
            # for host in array_dhcp_unikernel:
            #     GNS3Manager.start_node(host.id)

            # Each prompt below exits the script unless answered exactly "y".
            # The VC server is launched from a daemon thread so the prompts
            # can continue.
            if input("Start Ubuntu VC Server (y/N)?") != "y":
                sys.exit()
            threading.Thread(target=run_vc_server, daemon=True).start()

            # Press y to start Ubuntu bases
            # Nodes are started one by one with a short pause between them.
            if input("Start Ubuntu Switch nodes (y/N)?") != "y":
                sys.exit()
            for host in array_ubuntu_base:
                time.sleep(0.2)
                GNS3Manager.start_node(host.id)
            for host in array_ubuntu_intermediate:
                time.sleep(0.2)
                GNS3Manager.start_node(host.id)

            # Enter y to start router R1 and ubuntu server
            if input("Start Ubuntu Server and Router (y/N)?") != "y":
                sys.exit()
            GNS3Manager.start_node(router_0.id)
            GNS3Manager.start_node(ubuntu_server.id)

            # Capture packets from communications/links
            # Every link not touching the Ubuntu server, the router or PC 0 is
            # captured (link[0] is passed as the link id).
            # NOTE(review): array_window_hosts[0] raises IndexError when
            # NUM_HOST == 0.
            if input("Capture link (y/N)?") == "y":
                array_links = GNS3Manager.get_links()
                for link in array_links:
                    if ubuntu_server.id not in link and router_0.id not in link and array_window_hosts[0].id not in link:
                        GNS3Manager.start_capture_link(link[0])

            # Press y to start Windows PCs
            if input("Start Windows PCs (y/N)?") != "y":
                sys.exit()
            for host in array_window_hosts:
                time.sleep(1)
                GNS3Manager.start_node(host.id)

            # TCP listen on port 12345 to collect propagation time data
            # This blocks until a client connects and netcat exits. The
            # received bytes are redirected to measurement.txt.
            # NOTE(review): the zip command is a shell string built from
            # project_name without quoting.
            command = "nc -l -p 12345 > measurement.txt"
            try:
                subprocess.run(command, shell=True, check=True)
                print("Command executed successfully.")
                GNS3Manager.nocapture()  # Stop capturing packets after receiving data
                # Store data in a .zip
                subprocess.run("zip a.zip ~/GNS3/projects/" + project_name.replace(".gns3", "") + "/project-files/captures/*", shell=True, check=True)
            except subprocess.CalledProcessError as e:
                print(f"Error executing command: {e}")
        elif command == "node":
            my_node_id = sys.argv[2]
            node_list = GNS3Manager.get_nodes()
            for n in node_list:
                if my_node_id in n[0]:
                    print(n[1])
            sys.exit()
        # elif command == "duplicate":
        #     PROJECT_ID = sys.argv[2]
        #     duplicate_project(project_name)
        #     sys.exit()
        elif command == "capture":
            array_links = GNS3Manager.get_links()
            for link in array_links:
                GNS3Manager.start_capture_link(link[0])
            sys.exit()
        elif command == "nocapture":
            GNS3Manager.nocapture()
        elif command == "read":
            # Module-level state consumed through `global` by process_packet().
            array_point = []
            map_infecting = {}
            map_infector = {}
            continue_sniffing = True
            is_infected = False
            malware_instance = 0 if sys.argv[2] == "wannacry" else 1

            # Captures are read from a hard-coded project root under /mnt/raid5.
            # folder_path_capture = "/home/anh/GNS3/projects/" + project_name.replace(".gns3","") + "/project-files/captures/"
            folder_path_capture = "/mnt/raid5/projects/" + project_name.replace(".gns3", "") + "/project-files/captures/"
            files = os.listdir(folder_path_capture)
            pcap_files = [(folder_path_capture + file) for file in files if file.endswith('.pcap')]
            # The flags are reset per file, so each capture is scanned until its
            # first marker packet. array_point/map_infector accumulate across files.
            for pcap_file in pcap_files:
                continue_sniffing = True
                is_infected = False
                sniff(offline=pcap_file, filter="port 445", prn=process_packet, stop_filter=lambda _: not continue_sniffing, verbose=False)
            sys.exit()
# else:
# This is for testing a single function
# my_templates = GNS3Manager.get_templates()
# GNS3Manager.create_switch(MAX_SIZE_OF_SWITCH_PORT,0,0,1)
# GNS3Manager.create_ubuntu_base(0, 0)
# malware_instance = 0
# GNS3Manager.create_window(0, 0, 0)
# GNS3Manager.create_unikernel(0, 0, DHCP_UNIKERNEL_TYPE)
# GNS3Manager.create_ubuntu_server(0, 0)