かまたま日記3

プログラミングメイン、たまに日常

Lambdaオブジェクトの型パラメータを取るのは難しい

TL;DR

  • ラムダオブジェクトの型パラメータを取得するスマートな方法は今の所見つかっていない
  • もし基盤プログラムでそういうことをしたい場合は、ラムダを禁止して、匿名クラスを使う
  • いい方法があったら教えてください

本文

Javaで基盤プログラム的なのを作るとき、ジェネリクスの型パラメータを取得したいことがあります。普通のクラスや匿名クラスの場合は以下のようなリフレクションのコードで取得することができます。

import java.lang.reflect.ParameterizedType;
import java.util.function.Consumer;

public class FooTest {
    public static void main(String[] args) {
        System.out.println(getGenericTypeParam(new Foo()));
        System.out.println(getGenericTypeParam(new Bar()));
        System.out.println(getGenericTypeParam(Baz));
    }

    private static Class<?> getGenericTypeParam(Consumer consumer) {
        ParameterizedType type = (ParameterizedType) consumer.getClass().getGenericInterfaces()[0];
        return (Class) type.getActualTypeArguments()[0];
    }

    private static class Foo implements Consumer<String> {
        @Override
        public void accept(String s) {}
    }

    private static class Bar implements Consumer<Integer> {
        @Override
        public void accept(Integer s) {}
    }

    private static Consumer<Void> Baz = new Consumer<Void>() {
        @Override
        public void accept(Void aVoid) {
        }
    };
}

実行結果

class java.lang.String
class java.lang.Integer
class java.lang.Void

しかし、これがラムダになると、getGenericInterfaces メソッドの結果が ParameterizedType ではなく単なる java.lang.Object のクラス型になり、ClassCastExceptionが発生してしまいます。

getGenericTypeParam((Consumer<Byte>) (b -> {}));

結果

Exception in thread "main" java.lang.ClassCastException: java.lang.Class cannot be cast to java.lang.reflect.ParameterizedType

つまり、ラムダ式で生成された関数オブジェクトからは型パラメータの情報が消えているのです。これをどうにかして取得したいといろいろ模索していたのですが、結局ダメでした。一番惜しかったのはこちらのGistのやり方です。

ラムダ式が定義されているクラスで getDeclaredMethods を使ってメソッド一覧を見ると、そのクラス内で定義されたラムダ式に対応したSyntheticメソッドが生成されています。

import java.lang.reflect.Method;

public class FooTest {
    public static void main(String[] args) {
        Runnable task1 = () -> task2();
        System.out.println(task1.getClass().getName());
        task1.run();

        System.out.println();
        for (Method method : FooTest.class.getDeclaredMethods()) {
            System.out.println(method);
        }
    }

    private static void task2() {
        Runnable task2 = () -> {};
        System.out.println(task2.getClass().getName());
    }
}

上記のプログラムを実行すると以下のような結果が出力されます。

FooTest$$Lambda$1/664223387
FooTest$$Lambda$2/666641942

public static void FooTest.main(java.lang.String[])
private static void FooTest.task2()
private static void FooTest.lambda$task2$1()
private static void FooTest.lambda$main$0()

ラムダ式で生成されたオブジェクトとSyntheticメソッドにはどちらも名前に $1 的なインデックスがついており、それぞれのインデックスが1対1で対応していそうです。そこで、先のGistに習って以下の getGenericTypeParamSmart を追加します。

private static Class<?> getGenericTypeParamSmart(Consumer consumer) {
    String functionClassName = consumer.getClass().getName();
    int lambdaMarkerIndex = functionClassName.indexOf("$$Lambda$");
    if (lambdaMarkerIndex == -1) { // Not a lambda
        return getGenericTypeParam(consumer);
    }

    String declaringClassName = functionClassName.substring(0, lambdaMarkerIndex);
    int lambdaIndex = Integer.parseInt(functionClassName.substring(lambdaMarkerIndex + 9, functionClassName.lastIndexOf('/')));

    Class<?> declaringClass;
    try {
        declaringClass = Class.forName(declaringClassName);
    } catch (ClassNotFoundException e) {
        throw new IllegalStateException("Unable to find lambda's parent class " + declaringClassName);
    }

    for (Method method : declaringClass.getDeclaredMethods()) {
        if (method.isSynthetic()
                && method.getName().startsWith("lambda$")
                && method.getName().endsWith("$" + (lambdaIndex - 1))
                && Modifier.isStatic(method.getModifiers())) {
            return method.getParameterTypes()[0];
        }
    }
    throw new IllegalStateException("Unable to find lambda's implementation method");
}

その上で、以下のコードを実行するとちゃんと型パラメータが取れてそうです

public static void main(String[] args) {
    System.out.println(getGenericTypeParamSmart(new Foo()));
    System.out.println(getGenericTypeParamSmart(new Bar()));
    System.out.println(getGenericTypeParamSmart(Baz));
    System.out.println(getGenericTypeParamSmart((Consumer<Byte>) (b -> {})));
    System.out.println(getGenericTypeParamSmart((Consumer<Long>) (l -> {})));
}
class java.lang.String
class java.lang.Integer
class java.lang.Void
class java.lang.Byte
class java.lang.Long

ただし、Gistにもコメントしましたが、ラムダ内でラムダを生成した場合、例えば、以下のパターンでは失敗します。

public static void main(String[] args) {
    Runnable task = () -> {
        System.out.println(getGenericTypeParamSmart(new Foo()));
        System.out.println(getGenericTypeParamSmart(new Bar()));
        System.out.println(getGenericTypeParamSmart(Baz));
        System.out.println(getGenericTypeParamSmart((Consumer<Byte>) (b -> {})));
        System.out.println(getGenericTypeParamSmart((Consumer<Long>) (b -> {})));
    };
    task.run();
}
class java.lang.String
class java.lang.Integer
class java.lang.Void
class java.lang.Long
Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 0

なぜでしょうか? ラムダオブジェクトとFooTestクラスに定義されているSyntheticメソッドを比較してみましょう

public static void main(String[] args) {
    Runnable task = () -> {
        Consumer<Byte> byteConsumer = b -> {};
        Consumer<Long> longConsumer = l -> {};
        System.out.println("byteConsumer: " + byteConsumer.getClass().getName());
        System.out.println("longConsumer: " + longConsumer.getClass().getName());
    };
    System.out.println("task: " + task.getClass().getName());
    task.run();

    System.out.println();
    for (Method method : FooTest.class.getDeclaredMethods()) {
        if (method.isSynthetic()) {
            System.out.println(method.toString());
        }
    }
}

これの実行結果は以下のようになります

task: FooTest$$Lambda$1/664223387
byteConsumer: FooTest$$Lambda$2/1349393271
longConsumer: FooTest$$Lambda$3/159413332

private static void FooTest.lambda$main$2()
private static void FooTest.lambda$null$1(java.lang.Long)
private static void FooTest.lambda$null$0(java.lang.Byte)

つまり:

  • task オブジェクト FooTest$$Lambda$1 に対応するSyntheticメソッドは lambda$main$2
  • byteConsumer オブジェクト FooTest$$Lambda$2 に対応するSyntheticメソッドは lambda$null$0
  • longConsumer オブジェクト FooTest$$Lambda$3 に対応するSyntheticメソッドは lambda$null$1

に、なるわけです。番号の対応がずれてるので、間違ったメソッドを検索してしまっていたわけです。そういう訳で、この方法は使えませんでした。

そのあといろいろ調べてみましたが、型パラメータをちゃんと取得する方法は見つかりませんでした。というわけで、こういうプログラムを書きたいときは今の所はラムダを禁止したほうが良さそうです。

private static Class<?> getGenericTypeParam(Consumer consumer) {
    String functionClassName = consumer.getClass().getName();
    if (functionClassName.contains("$$Lambda$")) {
        throw new UnsupportedOperationException("Lambda is not supported");
    }
    ParameterizedType type = (ParameterizedType) consumer.getClass().getGenericInterfaces()[0];
    return (Class) type.getActualTypeArguments()[0];
}

最終的なコードは以下のような感じになります。

import java.lang.reflect.ParameterizedType;
import java.util.function.Consumer;

public class FooTest {
    public static void main(String[] args) {
        System.out.println(getGenericTypeParam(new Foo()));
        System.out.println(getGenericTypeParam(new Bar()));
        System.out.println(getGenericTypeParam(Baz));
        try {
            System.out.println(getGenericTypeParam((Consumer<Byte>) (b -> {})));
        } catch (UnsupportedOperationException e) {}
    }

    private static Class<?> getGenericTypeParam(Consumer consumer) {
        String functionClassName = consumer.getClass().getName();
    if (functionClassName.contains("$$Lambda$")) {
        throw new UnsupportedOperationException("Lambda is not supported");
    }
        ParameterizedType type = (ParameterizedType) consumer.getClass().getGenericInterfaces()[0];
        return (Class) type.getActualTypeArguments()[0];
    }

    private static class Foo implements Consumer<String> {
        @Override
        public void accept(String s) {}
    }

    private static class Bar implements Consumer<Integer> {
        @Override
        public void accept(Integer s) {}
    }

    private static Consumer<Void> Baz = new Consumer<Void>() {
        @Override
        public void accept(Void aVoid) {}
    };
}

embulk-output-multiを作った

前職の同僚の @mtsmfm さんがつぶやいていたので、勢いで作ってみました。

github.com

使い方

2019/03/11時点の最新版は 0.4.0 です

outputs に複数のoutputの設定をリスト形式で記述するだけです、簡単ですね。

in:
  type: ...
out:
  type: multi
  outputs:
    - type: stdout
    - type: file
      path_prefix: out_file_
      file_ext: csv
      formatter:
        type: csv
    - type: s3
      ...

注意

エラーハンドリングについて

  • 複数プラグインのどれかが transaction (or resume) でエラーになった場合は、安全側に倒してその後のすべてのプラグインopen は実行されずに終わります。
  • すべてのプラグインopen まで到達した場合、outputの処理が順次実行されていきますが、もし途中で一つでも失敗したアウトプットがあった場合、そのタスクを失敗とみなして例外を投げます。

ConfigDiffに関して

現在は各プラグインのConfigDiffを <plugin_type>_<index_in_outputs> のタグを付けてMapで保存しています。例えば、上記の例で言うと最初のstdoutはstdout_0という名前が付きます。つまり一回ConfigDiffを出力したあとにconfig.ymlを書き換えて順番を変えたり違うプラグインに書き換えたりすると、Diffがマージされない、もしくは違うDiffがマージされてしまう可能性があります。ConfigDiffを使う場合は、一回実行した outputs の順番は変えない方が良いでしょう。

内部実装的な苦労話

通常のEmbulkのOutput pluginの実行フローは以下のようになっています.

  • transaction メソッドが呼ばれる。このメソッド内で下準備 (認証やパラメータの検証) をする。
  • OKなら渡されてきた OutputPlugin.Control オブジェクトの run メソッド (以下 コールバック) を呼びExecutor側に処理を移譲する
  • Execotorがタスク分割などを実施し、タスクごとに open メソッドが呼ばれる。このメソッドで返却した TransactionalPageOutput が送られてきたデータを処理していく。
    • タスクの処理が終わると commit メソッドが呼ばれタスクごとの TaskReport を作成し、返す
    • 失敗したタスクは abort が呼ばれる
  • 全部のタスクが終わったあと cleanup メソッドが実行され、後処理を行う
  • transactionメソッドは、コールバックから返却されたTask Reportsを元に ConfigDiff を作成し、返す

上記の流れを、multiプラグイントランザクション上で複数のプラグインに対してエミュレートする必要があります。そのため以下のような処理の流れを作りました。フローがとても複雑でマルチスレッドを駆使する必要があったので、難しかったですけど楽しかったですw

  • マルチスレッドで各pluginのtransactionを実施しつつ、ダミーのコールバックを渡す
  • すべてのプラグインのtransactionの検証が終わったあと (= ダミーのコールバックが呼ばれたあと) に元のコールバックを呼ぶ
    • ダミーは、オリジナルのコールバックが終わるまでは各プラグインのtransactionをブロックする必要がある
  • 元のコールバックが呼ばれると、multiプラグインのopenが呼ばれる。ここで各プラグインのTransactionalPageOutputを作成し、それぞれのPageOutputに処理(add, finish, closeなど)を移譲する。
    • ただし Page オブジェクトは使い回せないので、各プラグインごとにコピーする
    • どれかのプラグインが途中で失敗してた場合はmultiのタスク自体は失敗とみなして例外を投げる、その後、全プラグインのabortが呼ばれる
  • 元のcleanupが実行される、保持しておいた個々のTaskReportを復元して、各プラグインのcleanupを実行する
  • すべてのタスク終わったあとに。ダミーのコールバックのブロックを解除する。TaskReport内に個々プラグインのTaskReportが保持されているので、復元して個々のプラグインに渡す。
  • 返したTaskReportを元に個々のプラグインのtransactionがConfigDiffを返す。それを集めてmultiプラグインのConfigDiffとして返す

GradleプラグインをGradle community portalにアップロードした

今まで、自作の2つのGradleプラグインGitHub上のオレオレMavenリポジトリから落とすようにしていたんですが、 buildscript でそのリポジトリを指定しないと行けなかったり、記述がちょっとだけ面倒だったので、Gradle community portalにアップしてPlugins DSLで書けるようにしました。

アップロードの方法は簡単で、基本的にこちらの手順に従うだけです

plugins.gradle.org

  • ポータルのアカウントを作る
  • APIキーを作成
  • Gradle plugin publishing pluginを使って諸々の設定をbuild.gradleに書く (ref)
  • publishPlugin タスクを実行

初回のみapprovalが必要でちょっと時間がかかりますが、自分の場合は半日くらいで承認されました。

JUnit 5 入門

そろそろ使ってみるかということで入門してみました。

JUnit Jupiter

こちらにも書かれてますが、JUnit 5は複数のサブプロジェクトからなり、JUnit 5でテストを書いたり拡張機能を書くためのクラスはJUnit Jupiterというプロジェクトにあります。なので、テストを書く場合は org.junit.jupiter 配下の各種ライブラリーをインポートして使うことになります。

Gradle から使う

最低限以下の記述が必要になります。 (Gradle 4.6以上が必要です)

dependencies {
    testCompile('org.junit.jupiter:junit-jupiter-api:5.3.2')
    testRuntime('org.junit.jupiter:junit-jupiter-engine:5.3.2')
}

test {
    useJUnitPlatform()
}
  • junit-jupiter-apiJUnit Jupiter のテストを書くのに必要なクラス、アノテーション群があります。 testCompile で指定します。
  • junit-jupiter-engineJUnit Jupiter のテストを実行するための TestEngine 実装です。testRuntime で指定します。
  • testタスク内のコンフィグレーションで useJUnitPlatform を指定することで、JUnit 5のプラットフォームを使うように宣言します。

テストを書く、実行する

基本的な書き方はJUnit 4までと同じで、テストしたいメソッドに @Testアノテーションを追加します。アサーションorg.junit.jupiter.api.AssertionsクラスにJUnit 3まででお馴染みの assertEquals などの基本的なアサーションメソッドがあるので、それを使います。JUnit 4時代の assertThat や、AssertJなどの別のアサーションライブラリを使いたい場合は、別途HamcrestやAssertJをインストールして使うことができます。

package com.example.project;

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

class SampleTests {
    @Test
    void onePlusOneEqualsTwo() {
        assertEquals(2, 1 + 1);
    }
}

実行はGradleのテストタスクで実行します。

$ ./gradlew test

> Task :test

com.example.project.SampleTests > onePlusOneEqualsTwo() PASSED

BUILD SUCCESSFUL in 1s
3 actionable tasks: 2 executed, 1 up-to-date

ちょっと高度な使い方集

BeforeとかAfterとか

@Before @After@BeforeEach @AfterEach に、 @BeforeClass @AfterClass@BeforeAll, @AfterAll に置き換えられました。

package com.example.project;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

class DBTests {
    @BeforeAll
    static void initializeDB() {
        System.out.println("Initializing Database...");
    }

    @AfterAll
    static void deleteDB() {
        System.out.println("Deleting Database...");
    }

    @BeforeEach
    void insertData() {
        System.out.println("Inserting test data...");
    }

    @AfterEach
    void clearData() {
        System.out.println("Clearing test data...");
    }

    @Test
    void testWithDB() {
        System.out.println("Testing...");
    }
}

Parameterized Test

junit-jupiter-params のライブラリをインストールした上で @ParameterizedTest をテストメソッドに追加します。パラメタのソースは、簡易的には @CsvSourceCSV文字列で指定できます。もうちょっと高度にやりたい場合は @ArgumentsSource アノテーションを使うことで独自の引数のProviderを指定することができます。

dependencies {
    ....
    testCompile('org.junit.jupiter:junit-jupiter-params:5.3.2')
    ....
}
package com.example.project;

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;

import static org.junit.jupiter.api.Assertions.assertEquals;

class ParameterizedTests {
    @ParameterizedTest(name = "{0} + {1} = {2}")
    @CsvSource({
            "0,    1,   1",
            "1,    2,   3",
            "49,  51, 100",
            "1,  100, 101"
    })
    void testsForPlus(int first, int second, int expected) {
        assertEquals(expected, first + second, first + " + " + second + " should equal " + expected);
    }
}

JUnit 4のテストを実行する

junit-vintage-enginetestRuntime でインストールします。これはJUnit 4以前のテストを実行するためのTestEngineの実装です。

dependencies {
    ...
    testCompile "junit:junit:4.12"
    testRuntime "org.junit.vintage:junit-vintage-engine:5.3.2"
    ...
}
package com.example.project;

import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;

import org.junit.Test;

public class JUnit4Tests {
    @Test
    public void test() {
        assertThat(1 + 2, is(3));
    }
}

拡張機能

JUnit 4の @Rule のような拡張機能は任意の Extension classを実装することで実現できます。 詳しくは この辺 参照。

package com.example.project;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.platform.commons.support.AnnotationSupport;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.Optional;

class ExtensionTests {
    @Test
    @MyExtension("FOO")
    void extensionTest() {
        System.out.println("Running a test...");
    }

    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    @ExtendWith(MyExtensionImpl.class)
    private @interface MyExtension {
        String value();
    }

    private static class MyExtensionImpl implements BeforeEachCallback {
        @Override
        public void beforeEach(ExtensionContext context) {
            final Optional<MyExtension> annotation = AnnotationSupport.findAnnotation(context.getTestMethod(), MyExtension.class);
            System.out.println(String.format("Running my extension with %s...", annotation.get().value()));
        }
    }
}

参考

gradle-embulk-plugin v0.3.0 リリース

Release v0.3.0 · kamatama41/gradle-embulk-plugin · GitHub

gem, gemPush のタスクの内容を最新のEmbulkのものに追従しました。詳しい使い方は以下の記事を参照ください。

kamatama41.hatenablog.com

embulk-filter-hash v0.5.0 リリース

イシューで希望をくれたHMACのハッシュ化に対応しました。以下のような感じで algorithm にHMACのアルゴリズムを指定した上で secret_key秘密鍵を入れると使えます

filters:
  - type: hash
    columns:
    - { name: username }
    - { name: phone_number, algorithm: HmacSHA256, secret_key: passw0rd }

JavaでRubyのeach_sliceがしたい

each_sliceというのは配列を指定した要素数の配列に分ける処理です。リストの中身をn件ごとに処理するときに便利です。 Javaには同様の処理が(たぶん)標準APIには無いので、こんな感じで行けそうです。

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class Test
{
    public static void main(String[] args)
    {
        // [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]
        System.out.println(slice(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 3));
    }

    private static <T> List<List<T>> slice(List<T> list, int n)
    {
        final int resultSize = (int) Math.ceil((double) list.size() / n);
        return IntStream.range(0, resultSize)
                .mapToObj(i -> list.subList(n * i, Math.min(list.size(), n * (i + 1))))
                .collect(Collectors.toList());
    }
}

参考: Is there a way to do the Ruby each_slice in Java 8? - Stack Overflow