update analyze codes

This commit is contained in:
luffy06 2024-01-16 17:31:18 +08:00
parent 6215c33a2b
commit 45c292cf28

View file

@ -1,5 +1,6 @@
import os import os
import argparse import argparse
import numpy as np
class Node: class Node:
def __init__(self, name, op, backend, shape): def __init__(self, name, op, backend, shape):
@ -49,17 +50,29 @@ def read_graph(file_path, skip_pattens=[]):
node = Node(name, op, backend, shape) node = Node(name, op, backend, shape)
if do_skip(name): if do_skip(name):
continue continue
nodes[name] = node
source = lines[i + 4].split("[")[1].split("]")[0] source = lines[i + 4].split("[")[1].split("]")[0]
source = list(map(lambda x: x, source.split(", "))) source = list(map(lambda x: x, source.split(", ")))
for pre_node in source: if name.startswith("norm-"):
if do_skip(pre_node): ffn_norm = False
for prev_node in source:
if "ffn_inp" in prev_node:
ffn_norm = True
if ffn_norm:
name = "ffn_" + name
else:
name = "attn_" + name
if name in nodes:
continue
nodes[name] = node
for prev_node in source:
if do_skip(prev_node):
continue continue
if pre_node not in nodes: if prev_node not in nodes:
nodes[pre_node] = Node(pre_node, "", "", []) nodes[prev_node] = Node(prev_node, "", "", [])
edges.append((pre_node, name)) edges.append((prev_node, name))
for prev, next in edges: for prev, next in edges:
nodes[next].in_deg += 1 nodes[next].in_deg += 1
@ -69,28 +82,39 @@ def read_graph(file_path, skip_pattens=[]):
return nodes return nodes
def compute_concur(start, nodes): def travel_in_topology(nodes, show_path=False):
concur = 1 degrees = {name: node.in_deg for name, node in nodes.items()}
order = 0 concur = 0
queue = [(order, start)] orders = {}
queue = []
for name, degree in degrees.items():
orders[name] = 1
if degree == 0:
queue.append(name)
while len(queue) > 0: while len(queue) > 0:
if order != queue[0][0]: cur_node = queue.pop(0)
concur = len(queue)
order = queue[0][0]
cur_order, cur_node = queue.pop(0)
for next_node in nodes[cur_node].next: for next_node in nodes[cur_node].next:
queue.append((cur_order + 1, next_node)) degrees[next_node] -= 1
return concur orders[next_node] = np.max((orders[next_node], orders[cur_node] + 1))
if degrees[next_node] == 0:
queue.append(next_node)
if show_path:
for prev_node in nodes[next_node].prev:
print(f"{prev_node} -> {next_node}")
return concur, orders
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--log_file", required=True, type=str) parser.add_argument("--log_file", required=True, type=str)
parser.add_argument("--start_node", type=str, default=None)
parser.add_argument("--show_path", action=argparse.BooleanOptionalAction)
args = parser.parse_args() args = parser.parse_args()
gf = read_graph(args.log_file, skip_pattens=[".weight"]) gf = read_graph(args.log_file, skip_pattens=[".weight"])
max_concur = 1 concur, orders = travel_in_topology(gf, show_path=args.show_path)
for name, node in gf.items(): print(f"max concurrency: {concur}, max order {np.max(list(orders.values()))}")
if node.in_deg == 0:
concur = compute_concur(name, gf)
print(f"Start node: {name}, Max concurrency: {concur}")