Merge branch 'shangyu' of https://github.com/luffy06/llama.cpp into shangyu
This commit is contained in:
commit
0bd702984c
1 changed files with 53 additions and 21 deletions
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
class Node:
|
||||
def __init__(self, name, op, backend, shape):
|
||||
|
@ -49,17 +50,29 @@ def read_graph(file_path, skip_pattens=[]):
|
|||
node = Node(name, op, backend, shape)
|
||||
if do_skip(name):
|
||||
continue
|
||||
nodes[name] = node
|
||||
|
||||
source = lines[i + 4].split("[")[1].split("]")[0]
|
||||
source = list(map(lambda x: x, source.split(", ")))
|
||||
|
||||
for pre_node in source:
|
||||
if do_skip(pre_node):
|
||||
if name.startswith("norm-"):
|
||||
ffn_norm = False
|
||||
for prev_node in source:
|
||||
if "ffn_inp" in prev_node:
|
||||
ffn_norm = True
|
||||
if ffn_norm:
|
||||
name = "ffn_" + name
|
||||
else:
|
||||
name = "attn_" + name
|
||||
if name in nodes:
|
||||
continue
|
||||
nodes[name] = node
|
||||
|
||||
for prev_node in source:
|
||||
if do_skip(prev_node):
|
||||
continue
|
||||
if pre_node not in nodes:
|
||||
nodes[pre_node] = Node(pre_node, "", "", [])
|
||||
edges.append((pre_node, name))
|
||||
if prev_node not in nodes:
|
||||
nodes[prev_node] = Node(prev_node, "", "", [])
|
||||
edges.append((prev_node, name))
|
||||
|
||||
for prev, next in edges:
|
||||
nodes[next].in_deg += 1
|
||||
|
@ -69,28 +82,47 @@ def read_graph(file_path, skip_pattens=[]):
|
|||
|
||||
return nodes
|
||||
|
||||
def compute_concur(start, nodes):
|
||||
concur = 1
|
||||
order = 0
|
||||
queue = [(order, start)]
|
||||
def travel_in_topology(nodes, show_path=False):
|
||||
degrees = {name: node.in_deg for name, node in nodes.items()}
|
||||
concur = 0
|
||||
orders = {}
|
||||
queue = []
|
||||
|
||||
for name, degree in degrees.items():
|
||||
orders[name] = 1
|
||||
if degree == 0:
|
||||
queue.append(name)
|
||||
|
||||
if show_path:
|
||||
print("Graph:")
|
||||
|
||||
while len(queue) > 0:
|
||||
if order != queue[0][0]:
|
||||
concur = len(queue)
|
||||
order = queue[0][0]
|
||||
cur_order, cur_node = queue.pop(0)
|
||||
cur_node = queue.pop(0)
|
||||
for next_node in nodes[cur_node].next:
|
||||
queue.append((cur_order + 1, next_node))
|
||||
return concur
|
||||
degrees[next_node] -= 1
|
||||
orders[next_node] = np.max((orders[next_node], orders[cur_node] + 1))
|
||||
if degrees[next_node] == 0:
|
||||
queue.append(next_node)
|
||||
if show_path:
|
||||
for prev_node in nodes[next_node].prev:
|
||||
print(f"\t{prev_node} -> {next_node}")
|
||||
|
||||
return orders
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--log_file", required=True, type=str)
|
||||
parser.add_argument("--show_path", action=argparse.BooleanOptionalAction)
|
||||
parser.add_argument("--show_detailed_orders", action=argparse.BooleanOptionalAction)
|
||||
args = parser.parse_args()
|
||||
|
||||
gf = read_graph(args.log_file, skip_pattens=[".weight"])
|
||||
|
||||
max_concur = 1
|
||||
for name, node in gf.items():
|
||||
if node.in_deg == 0:
|
||||
concur = compute_concur(name, gf)
|
||||
print(f"Start node: {name}, Max concurrency: {concur}")
|
||||
orders = travel_in_topology(gf, show_path=args.show_path)
|
||||
order_values = list(orders.values())
|
||||
print(f"Max concurrency: {np.max([order_values.count(x) if x > 3 else 0 for x in np.unique(order_values)])}\nMax order {np.max(order_values)}")
|
||||
if args.show_detailed_orders:
|
||||
print(f"Detailed orders:")
|
||||
for name, order in sorted(list(orders.items()), key=lambda x: x[1]):
|
||||
print(f"\t{name}: {order}")
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue