@relgames
Java Developer

Почему Fork/Join из JDK7 медленно работает?

Пишу простую программу для генерации ряда Фарея


В начале написал простую рекурсию, ряд 500-го порядка генерирует за 100мс
Скрытый текст
public class NestedIntervals {
    private final int base;

    private NestedIntervals(int base) {
        this.base = base;
    }

    public static NestedIntervals base(int base) {
        return new NestedIntervals(base);
    }

    private List<Fraction> getFarey(Fraction left, Fraction right) {
        Fraction mediant = new Fraction(left.getNumerator()+right.getNumerator(), left.getDenominator()+right.getDenominator());
        if (mediant.getDenominator()>base) {
            return Collections.emptyList();
        }
        List<Fraction> result = new LinkedList<>();
        result.addAll(getFarey(left, mediant));
        result.add(mediant);
        result.addAll(getFarey(mediant, right));
        return result;
    }

    public List<Fraction> getFarey() {
        List<Fraction> result = new LinkedList<>();

        Fraction left = new Fraction(0, 1);
        Fraction right = new Fraction(1, 1);

        result.add(left);
        result.addAll(getFarey(left, right));
        result.add(right);

        return result;
    }

    public static void main(String[] args) {
        long time = System.currentTimeMillis();
        List<Fraction> farey = NestedIntervals.base(500).getFarey();

        int max = 0;
        for (Fraction f: farey) {
            if (f.getDenominator()>max) {
                max = f.getDenominator();
            }
        }

        System.out.printf("Total %dms", System.currentTimeMillis()-time);
    }
}



На рядах бОльшего порядка работает медленно. Задумал переписать на Fork/Join
Скрытый текст
public class FareyList extends RecursiveTask<List<Fraction>>{
    private final static int base = 500;

    private final Fraction left;
    private final Fraction right;

    private FareyList(Fraction left, Fraction right) {
        this.left = left;
        this.right = right;
    }

    @Override
    protected List<Fraction> compute() {
        Fraction mediant = new Fraction(left.getNumerator()+right.getNumerator(), left.getDenominator()+right.getDenominator());
        if (mediant.getDenominator()>base) {
            return Collections.emptyList();
        }

        FareyList leftList = new FareyList(left, mediant);
        FareyList rightList = new FareyList(mediant, right);
        leftList.fork();
        rightList.fork();

        List<Fraction> result = new LinkedList<>();
        result.addAll(leftList.join());
        result.add(mediant);
        result.addAll(rightList.join());
        return result;
    }

    public static List<Fraction> getFareyList() {
        List<Fraction> result = new LinkedList<>();

        Fraction left = new Fraction(0, 1);
        Fraction right = new Fraction(1, 1);

        FareyList task = new FareyList(left, right);

        new ForkJoinPool().invoke(task);

        result.add(left);
        result.addAll(task.join());
        result.add(right);

        return result;
    }

    public static void main(String[] args) {
        long time = System.currentTimeMillis();
        List<Fraction> farey = FareyList.getFareyList();

        int max = 0;
        for (Fraction f: farey) {
            if (f.getDenominator()>max) {
                max = f.getDenominator();
            }
        }

        System.out.printf("Total %dms", System.currentTimeMillis()-time);
    }

}



Этот код выполняется за 12 секунд. 4х ядерный Xeon, должно быть быстрее, чем рекурсивная версия, однако на практике все наоборот. Что я делаю не так? Подходит ли вообще Fork/Join для подобной задачи?
  • Вопрос задан
  • 3958 просмотров
Решения вопроса 1
TheShade
@TheShade
В-нулевых, вы померили стартап, компиляцию и поднятие потоков в FJP, с чем вас и поздравляю.

Во-первых, сделайте уход в секвенциальную версию, начиная с некоторого маленького threshold'а. Иначе накладные расходы на распиливание и менеджмент задач всё сожрут.

Во-вторых, способа джойнить хуже, чем «t1.fork(); t2.fork(); t1.join(); t2.join()» не придумаешь. Лучший из известных способов делает «t1.fork(); t2.compute(); t1.join()».
Ответ написан
Пригласить эксперта
Ответы на вопрос 9
ivnik
@ivnik
1) Сделайте 2 запуска вычисления, и время вычисляйте во 2-м запуске (время должно уменьшится за счёт отсутствия загрузки классов)
2) При разделении задачи не доходите до самого конца, сделайте некоторое количество fork-ов, а затем просто вызовите рекурсивную функцию
Ответ написан
rfq
@rfq
Программист
Я переписал ваш код, избегая излишнего создания объектов и перекладывания объектов из коллекции в коллекцию. Распаралелленые варианты медленнее рекурсивного, но не драматично — всего в несколько раз.
Как правильно уже указывали, медленнее из-за накладных расходов на ведение подзадач. Если суметь укрупнить подзадачи, то получим реальный выигрыш.
Ответ написан
Foror
@Foror
Графоман
Если надо быстрее, попробуйте такую штуку habrahabr.ru/post/149552
Ответ написан
Комментировать
@1nd1go
Я не разбираюсь в смысле задачи, поэтому по коду ничего сказать не могу (смущает меня некоторое рекурсивное создание ForkJoinPoolов, которое явно выходит за количество равное 1ому), но есть общее правило. Если у вас вычисления, которые нужно производить на процессоре, то имеет смысл пул задавать равным количеству доступных для JVM ядер. Поэтому, я бы создал 1 ForkJoinPool (дефолтный конструктор делает параллелизм равным количеству ядер), и туда сабмитил таски. а не создавал каждый раз новый пул на каждое дробление (или что там у вас делается).
Ответ написан
@relgames Автор вопроса
Java Developer
Забыл спросить, секвенциальную версия — это что?
Ответ написан
Foror
@Foror
Графоман
А вообще, если хотите добиться оптимальной скорости, то нужно вручную управлять нитями. Создаете фиксированный пул нитей (например, по одной на ядро), а дальше в этот пул закидываете ваши расчеты, которые выстраиваются в очередь и последовательно обрабатываются. Давно уже с этим не работал, так что не смог вам прям здесь пример накидать…
Ответ написан
Комментировать
Foror
@Foror
Графоман
И еще попробуйте поиграться с new ForkJoinPool(nThreads), сделайте nThreads = 2, например.
Ответ написан
lebedi
@lebedi
Переписал не используя рекурсию, так как стек не самая сильная сторона java, а также добавил версию с ExecutorService
Мои результаты тестов на BASE=5000
Total 14306ms (рекурсивная реализация)
n Total 1725ms (нерекурсивная реализация)
m Total 1241ms (многопоточная нерекурсивная реализация)
Исходный код
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class NestedIntervals {
    private static ExecutorService executor;
    private final int base;

    private NestedIntervals(int base) {
        this.base = base;
    }

    public static NestedIntervals base(int base) {
        return new NestedIntervals(base);
    }

    public static class Fraction {
        private int denominator;
        private int numerator;

        public Fraction(int numerator, int denominator) {
            super();
            this.numerator = numerator;
            this.denominator = denominator;
        }

        public int getDenominator() {
            return denominator;
        }

        public void setDenominator(int denominator) {
            this.denominator = denominator;
        }

        public int getNumerator() {
            return numerator;
        }

        public void setNumerator(int numerator) {
            this.numerator = numerator;
        }

        @Override
        public String toString() {
            return numerator + "/" + denominator;
        }

    }

    private List<Fraction> getFarey(Fraction left, Fraction right) {
        Fraction mediant = new Fraction(left.getNumerator()
                + right.getNumerator(), left.getDenominator()
                + right.getDenominator());
        if (mediant.getDenominator() > base) {
            return Collections.emptyList();
        }
        List<Fraction> result = new LinkedList<Fraction>();
        result.addAll(getFarey(left, mediant));
        result.add(mediant);
        result.addAll(getFarey(mediant, right));
        return result;
    }

    public List<Fraction> getFarey() {
        List<Fraction> result = new LinkedList<Fraction>();

        Fraction left = new Fraction(0, 1);
        Fraction right = new Fraction(1, 1);

        result.add(left);
        result.addAll(getFarey(left, right));
        result.add(right);

        return result;
    }

    public List<Fraction> getFareyNonRecurcive() {
        LinkedList<Fraction> result = new LinkedList<Fraction>();
        Fraction left = new Fraction(0, 1);
        Fraction right = new Fraction(1, 1);
        Fraction mediant = null;
        result.add(left);
        result.add(right);
        ListIterator<Fraction> iterator = result.listIterator();
        left = iterator.next();
        while (iterator.hasNext()) {
            right = iterator.next();
            int denominator = left.getDenominator() + right.getDenominator();
            if (denominator <= base) {
                mediant = new Fraction(left.getNumerator()
                        + right.getNumerator(), denominator);
                iterator.previous();
                iterator.add(mediant);
                iterator.previous();
            } else {
                left = right;
            }
        }
        return result;
    }

    public List<Fraction> getFareyNonRecurciveMultiThread(int threads) {
        List<Fraction> result = null;
        if (threads <= 1) {
            result = getFareyNonRecurcive();
        } else {

            LinkedList<Fraction> ret = new LinkedList<Fraction>();
            List<Future<LinkedList<Fraction>>> futureList = new
                    ArrayList<Future<LinkedList<Fraction>>>();
            // fill params
            LinkedList<Fraction> params = fillParams(threads);
            ListIterator<Fraction> listIterator = params.listIterator();
            Fraction first = listIterator.next();
            while (listIterator.hasNext()) {
                Fraction second = (Fraction) listIterator.next();
                Future<LinkedList<Fraction>> future = executor
                        .submit(new FareyCallable(first, second, base));
                futureList.add(future);
                first = second;
            }
            try {
                for (int i = 0; i < futureList.size(); i++) {
                    if (i == 0) {
                        ret.addAll(futureList.get(i).get());
                    } else {
                        LinkedList<Fraction> linkedList = futureList.get(i)
                                .get();
                        linkedList.pollFirst();
                        ret.addAll(linkedList);
                    }
                }
            } catch (Exception e) {
                e.printStackTrace();
            }

            result = ret;
        }
        return result;
    }

    private LinkedList<Fraction> fillParams(int threads) {
        LinkedList<Fraction> params = new LinkedList<Fraction>();
        params.add(new Fraction(0, 1));
        params.add(new Fraction(1, 1));
        Fraction left = null;
        Fraction right = null;
        Fraction mediant = null;
        for (int j = 0; j < base; j++) {
            ListIterator<Fraction> listIterator = params.listIterator();
            left = listIterator.next();

            while (listIterator.hasNext()) {
                if (params.size() - 1 == threads) {
                    return params;
                }
                right = listIterator.next();
                int denominator = left.getDenominator()
                        + right.getDenominator();
                if (denominator <= base) {
                    mediant = new Fraction(left.getNumerator()
                            + right.getNumerator(), denominator);
                    listIterator.previous();
                    listIterator.add(mediant);
                    left = listIterator.next();
                } else {
                    left = right;
                }

            }
        }
        throw new RuntimeException();
    }

    public static class FareyCallable implements Callable<LinkedList<Fraction>> {
        private Fraction first;
        private Fraction last;
        private int base;

        public FareyCallable(Fraction first, Fraction last, int base) {
            super();
            this.first = first;
            this.last = last;
            this.base = base;
        }

        @Override
        public LinkedList<Fraction> call() throws Exception {
            LinkedList<Fraction> result = new LinkedList<Fraction>();
            Fraction left = first;
            Fraction right = last;
            Fraction mediant = null;
            result.add(left);
            result.add(right);
            ListIterator<Fraction> iterator = result.listIterator();
            left = iterator.next();
            while (iterator.hasNext()) {
                right = iterator.next();
                int denominator = left.getDenominator()
                        + right.getDenominator();
                if (denominator <= base) {
                    mediant = new Fraction(left.getNumerator()
                            + right.getNumerator(), denominator);
                    iterator.previous();
                    iterator.add(mediant);
                    iterator.previous();
                } else {
                    left = right;
                }
            }
            return result;
        }

    }

    public static void main(String[] args) {
        int threads = 20;
        int base = 5000;
        executor = Executors.newFixedThreadPool(threads);

        for (int i = 0; i < 1; i++) {
            long time = System.currentTimeMillis();
            List<Fraction> farey = NestedIntervals.base(base).getFarey();


            System.out.printf("Total %dms ", System.currentTimeMillis() - time);
            System.out.println(farey.size());

            farey = null;
            System.gc();
            time = System.currentTimeMillis();
            farey = NestedIntervals.base(base).getFareyNonRecurcive();


            System.out.printf("n Total %dms ", System.currentTimeMillis()
                    - time);
            System.out.println(farey.size());

            farey = null;
            System.gc();
            time = System.currentTimeMillis();
            farey = NestedIntervals.base(base).getFareyNonRecurciveMultiThread(threads);

            System.out.printf("m Total %dms ",
                    System.currentTimeMillis() - time);
            System.out.println(farey.size());
            farey = null;
            System.gc();
        }
        executor.shutdownNow();
    }
}


Ответ написан
@Terran37
Программист
Привет. Можно прочитать вот тут ещё один ://javarules.ru/java-8-parallel/ про создание fork/join
Ответ написан
Комментировать
Ваш ответ на вопрос

Войдите, чтобы написать ответ

Войти через центр авторизации
Похожие вопросы