Задать вопрос
@keddad
Ученик

Как максимально ускорить данный код на Python?

Есть следующий код, который решает эту задачу
from array import array

n, m = map(int, input().split())

parent, weight, rank = array('i', [-1 for _ in range(n)]), array('I', [0 for _ in range(n)]), array('I', [1 for _ in
                                                                                                          range(n)])


def find_set(v: int) -> int:
    if parent[v] == -1:
        parent[v] = v
        return v
    if v == parent[v]:
        return v
    parent[v] = find_set(parent[v])
    weight[parent[v]] += weight[v]
    weight[v] = 0
    return parent[v]


def union_sets(a: int, b: int, cost: int) -> None:
    a = find_set(a)
    b = find_set(b)
    if a != b:
        if rank[a] < rank[b]:
            a, b = b, a
        parent[b] = a
        weight[a] += cost
        weight[a] += weight[b]
        weight[b] = 0
        if rank[a] == rank[b]:
            rank[a] += 1
    else:
        weight[a] += cost


with open("input.txt", "r") as inp:
    with open("output.txt", "w") as out:
        inp.__next__()
        for line in inp:
            st = line.split()
            if len(st) != 4:
                out.write(str(weight[find_set(int(st[1]) - 1)]) + "\n")
            else:
                union_sets(int(st[1]) - 1, int(st[2]) - 1, int(st[3]))


К сожалению, он несколько не укладывается в временные рамки. Теоретически можно было бы просто написать тот же алгоритм на C++, но мы не ищем легких путей! Какие еще оптимизации можно применить к коду выше, что бы уменьшить время работы при больших обьемах данных?
  • Вопрос задан
  • 670 просмотров
Подписаться 1 Простой 1 комментарий
Решения вопроса 1
@keddad Автор вопроса
Ученик
Окей, в итоговой версии я избавился от рекурсии, типизации и Arrayев. Основной профит получил от рекурсии, конечно. Этого не хватило для решения задачи, но код заметно ускорился.
n, m = map(int, input().split())

parent, weight, rank = [-1 for _ in range(n)], [0 for _ in range(n)], [1 for _ in range(n)]


def find_set(v):
    while parent[v] != -1 and parent[v] != v:
        parent[v] = parent[parent[v]]
        weight[parent[v]] += weight[v]
        weight[v] = 0
        v = parent[v]
    if parent[v] == -1:
        parent[v] = v
    return v


def union_sets(a, b, cost):
    a = find_set(a)
    b = find_set(b)
    if a != b:
        if rank[a] < rank[b]:
            a, b = b, a
        parent[b] = a
        weight[a] += cost
        weight[a] += weight[b]
        weight[b] = 0
        if rank[a] == rank[b]:
            rank[a] += 1
    else:
        weight[a] += cost


def main():
    with open("input.txt", "r") as inp:
        with open("output.txt", "w") as out:
            inp.__next__()
            for line in inp:
                st = line.split()
                if len(st) != 4:
                    out.write(str(weight[find_set(int(st[1]) - 1)]) + "\n")
                else:
                    union_sets(int(st[1]) - 1, int(st[2]) - 1, int(st[3]))


main()
Ответ написан
Пригласить эксперта
Ответы на вопрос 2
@deliro
1. Конструкция [-1 for _ in range(n)] уже создаёт список. Дальше этот список просто выкидывается и генерируется array.array. Итого в одной строчке сразу генерируются 6 потенциально огромных коллекций. Либо можно переделать на (-1 for _ in range(n)), либо отказаться от array.array, его преимущества здесь сомнительны:

In [3]: a = array("I", range(10000))                                                                                    

In [4]: b = list(range(10000))                                                                                          

In [5]: %timeit sum(a)                                                                                                  
206 µs ± 6.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [6]: %timeit sum(b)                                                                                                  
69.3 µs ± 367 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [7]: %timeit a[7777]                                                                                                 
49.5 ns ± 0.564 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

In [8]: %timeit b[7777]                                                                                                 
33.6 ns ± 0.411 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)


2. Убрать рекурсию. Совсем.
3. Удалить typing
Ответ написан
sgjurano
@sgjurano
Разработчик
Вы уверены, что у вас корректно выполняется сжатие путей?

Ну и от рекурсии лучше уйти, в питоне довольно дорогие вызовы функций.
Ответ написан
Комментировать
Ваш ответ на вопрос

Войдите, чтобы написать ответ

Похожие вопросы