sedコマンドでディレクトリ内の全ファイルをテキスト全置換するには

単一ファイルを置換する場合

テキストファイル内を全置換するにはsedコマンドでできます。

$ sed -i 's/BEFORE/AFTER/g' FILENAME
  • s/BEFORE/AFTER/gsの意味: BEFOREAFTERに置換するというsedのコマンド。sedコマンドは置換だけではなく、行抽出や削除などいろんな処理ができる。
  • s/BEFORE/AFTER/ggの意味: BEFOREにマッチした部分をすべて置換する。gがないと、各行で1つ目のみを置換する。1つ目のみでいいケースは少ないから、置換のときはなにも考えずにgを付けるぐらいでもよい。
  • -i: 全置換してファイルを書き換える。-iがないと、置換結果を標準出力するのみで、ファイル書き換えはしない。

$ cat test.txt 
abcdefabc
abc

$ sed 's/b/B/g' test.txt # 全置換
aBcdefaBc
aBc

$ sed 's/b/B/' test.txt  # 各行1つ目だけ置換
aBcdefabc
aBc

バックアップを残したい場合

-i-i.bak のように拡張子を付けると、ファイル名の最後に.bak を付けた名前で置換前のバックアップを取ってくれます。

ディレクトリ内の全ファイルを置換したい場合

findコマンドとxargsコマンドを併用

カレントディレクトリ内のすべてのファイルを再帰的に探索して、全ファイルを置換するには、findコマンドとxargsコマンドをsedに組み合わせればできます。

$ find . -type f | xargs sed -i 's/BEFORE/AFTER/g'

findコマンドで見つけたファイルをxargsコマンドを使ってsedコマンドに渡します。

findコマンドでファイルを絞ることもできます。

$ find . -type f -name "*.txt" | xargs sed -i 's/BEFORE/AFTER/g'

ただし、この方法は、findコマンドで見つけたファイルをsedコマンドで保存しなおしますので、置換対象の文字列がなかったとしても、タイムスタンプがすべて現在時刻に置き換わります。それが気持ち悪い場合はgrepも併用するとよいです。

grepコマンドも併用

grepコマンドを併用することで、置換対象の文字列があるファイルのみをsedコマンドに渡します。

$ find . -type f -name "*.txt" | grep -rl BEFORE | xargs sed -i 's/BEFORE/AFTER/g'

grepを使う場合で、findでファイルを絞る必要がない場合はfindは不要です。

$ grep -rl BEFORE . | xargs sed -i 's/BEFORE/AFTER/g'

全置換が怖いからバックアップを残したい

sedのバックアップオプション

sedコマンドでバックアップ拡張子を指定すれば、sedコマンドに渡された全ファイルのバックアップが残されます。

$ find . -type f -name "*.txt" | grep -rl BEFORE | xargs sed -i.bak 's/BEFORE/AFTER/g'

この方法は、大量のバックアップファイルが生成されてしまって、後片付けが面倒という問題があります。置換結果の差分を見るのも面倒です。

ディレクトリごとバックアップ

全置換で大量のファイルが更新されるのがちょっと怖いという場合は、置換前にディレクトリごとコピーしてしまって、置換後にディレクトリごとdiffコマンドで見ることが私は多いです。

$ cp -rvp . ../backup

$ find . -type f -name "*.txt" | grep -rl BEFORE | xargs sed -i 's/BEFORE/AFTER/g'

$ diff -r ../backup .

gitを使っていればバックアップ不要

git管理下にあれば、バックアップとらずに git diff で置換結果を見て、もとにもどしたければ git checkout すればよいです。

Google Cloud DataflowをPythonで動かしてみる

前回のGoogle Cloud DataflowをJavaで動かしてみるに続き、次はPythonで試しました。

前提

Pythonライブラリインストール

pip install apache-beam[gcp] というコマンドはbashではそのままでいいですが、zshでは私の環境ではシングルクオートで囲む必要がありました。

$ pip install 'apache-beam[gcp]'

ソースコード

Pythonの1ファイルのみです。

import argparse

from past.builtins import unicode

import apache_beam as beam
from apache_beam.io import ReadFromText
from apache_beam.io import WriteToText
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions

class ExtractWordsFn(beam.DoFn):
  def process(self, element):
    return element.split(" ")

def run(argv=None, save_main_session=True):
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--inputPath',
      dest='inputPath',
      required=True,
      help='Path of the file to read from')
  parser.add_argument(
      '--outputPath',
      dest='outputPath',
      required=True,
      help='Path of the file to write to')
  known_args, pipeline_args = parser.parse_known_args(argv)

  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  with beam.Pipeline(options=pipeline_options) as p:
    def format_result(word, count):
      return '%s: %d' % (word, count)

    ( p |
      'ReadLines' >> ReadFromText(known_args.inputPath) |
      'Split' >> (beam.ParDo(ExtractWordsFn()).with_output_types(unicode)) |
      (beam.ParDo(ExtractWordsFn()).with_output_types(unicode)) |
      'PairWIthOne' >> beam.Map(lambda x: (x, 1)) |
      'GroupAndSum' >> beam.CombinePerKey(sum) |
      'Format' >> beam.MapTuple(format_result) |
      'WriteCounts' >> WriteToText(known_args.outputPath)
    )

if __name__ == '__main__':
  run()

実行方法

$ python mysample.py --project=PROJECT_ID --region=us-central1 --runner=DataflowRunner --temp_location=TEMP_GS_PATH --inputPath=gs://apache-beam-samples/shakespeare/kinglear.txt --outputPath=OUTPUT_GS_PATH

上記コマンドの中、以下の3箇所は自分の環境に合わせて書き換えます。

  • PROJECT_ID: GCPのプロジェクトID、
  • TEMP_GS_PATH: Dataflowの実行に必要なファイルを置くGCSのパス。このパスの中に自動で必要なファイルがアップロードされる。例: gs://sample-bucket/dataflow/temp/
  • OUTPUT_GS_PATH: 出力先GCSパス。 例: gs://sample-bucket/dataflow/output と指定すると gs://sample-bucket/dataflow/output-00000-of-00003 のような名前のファイルが作成される

コマンドの中で入力として指定している gs://apache-beam-samples/shakespeare/kinglear.txt にはApacheが用意しているサンプル入力ファイルがあり、誰でも読み込み可能です。

このコマンドを実行すると、4分〜5分で完了します。GCPコンソールでDataflowのジョブを見ると次のように表示されています。

f:id:suzuki-navi:20201009233622p:plain

JavaPythonの違い

  • Dataflowの処理時間が違う
    • Javaは3分前後
    • Pythonは4分から5分程度かかる
  • Javaでの --stagingLocation で指定するGoogle Cloud Storageの中身とPythonでの --temp_location で指定するGoogle Cloud Storageの中身は、どちらも自動で実行時に必要なファイルが置かれるが、構成が違う
    • Java--stagingLocation で指定のディレクトリ直下に大量のjarファイルが保存される
    • Python--temp_location で指定のディレクトリ内に日時を含むテンポラリのディレクトリが作られ、その中になにかが保存される。ファイル数もJavaよりずっと少ない

Ubuntu 20.04にC#をインストールしてAWS SDKを使ってみる

C#をほとんど触ったことがないので、手元のUbuntu 20.04に入れてHelloWorldを書きました。そしてC#AWS SDKをインストールしてAWSAPIにアクセスしてみました。

環境はUbuntu 20.04です。

C#のインストール

手順はMicrosoftのサイトに書いてあるとおりです。

$ wget https://packages.microsoft.com/config/ubuntu/20.04/packages-microsoft-prod.deb -O packages-microsoft-prod.deb

$ sudo dpkg -i packages-microsoft-prod.deb

$ sudo apt-get update

$ sudo apt-get install -y apt-transport-https dotnet-sdk-3.1
$ dotnet --version
3.1.402

C#のHelloWorld

CLIプログラムのテンプレート作成と実行は dotnet コマンドでできるようです。空ディレクトリを作成して、その中でテンプレート作成します。

$ mkdir sample
$ cd sample
$ dotnet new console

これを実行すると、以下の通りファイル一式がカレントディレクトリに作成されます。

$ tree
.
├── Program.cs
├── obj
│   ├── project.assets.json
│   ├── project.nuget.cache
│   ├── sample.csproj.nuget.dgspec.json
│   ├── sample.csproj.nuget.g.props
│   └── sample.csproj.nuget.g.targets
└── sample.csproj

1 directory, 7 files

$ cat Program.cs 
using System;

namespace sample
{
    class Program
    {
        static void Main(string[] args)
        {
            Console.WriteLine("Hello World!");
        }
    }
}

ソースコードはすでにHelloWorldになってます。

実行は以下のコマンドです。

$ dotnet run
Hello World!

実行するとbinobjディレクトリにファイルが生成されます。

$ tree
.
├── Program.cs
├── bin
│   └── Debug
│       └── netcoreapp3.1
│           ├── sample
│           ├── sample.deps.json
│           ├── sample.dll
│           ├── sample.pdb
│           ├── sample.runtimeconfig.dev.json
│           └── sample.runtimeconfig.json
├── obj
│   ├── Debug
│   │   └── netcoreapp3.1
│   │       ├── sample
│   │       ├── sample.AssemblyInfo.cs
│   │       ├── sample.AssemblyInfoInputs.cache
│   │       ├── sample.assets.cache
│   │       ├── sample.csproj.CoreCompileInputs.cache
│   │       ├── sample.csproj.FileListAbsolute.txt
│   │       ├── sample.csprojAssemblyReference.cache
│   │       ├── sample.dll
│   │       ├── sample.genruntimeconfig.cache
│   │       └── sample.pdb
│   ├── project.assets.json
│   ├── project.nuget.cache
│   ├── sample.csproj.nuget.dgspec.json
│   ├── sample.csproj.nuget.g.props
│   └── sample.csproj.nuget.g.targets
└── sample.csproj

6 directories, 23 files

C# AWS SDKのインストール

AWSのサービスごとにC#のパッケージがあるようです。

NuGet Gallery | Packages matching id:AWSSDK owner:awsdotnet

今回はS3のパッケージをインストールしてみます。

$ dotnet add package AWSSDK.S3

このコマンドを実行すると、sample.csproj にパッケージの情報が書かれます。パッケージ自体は ~/.nuget/packages/ にダウンロードされるようです。

C#からS3のAPIにアクセス

Program.cs を以下のように書きます。AWSアカウント内にあるS3バケット一覧を取得してバケット名をシンプルに表示するのみです。

using System;
using System.Threading;
using Amazon;
using Amazon.S3;

namespace S3Sample
{
    class S3Sample
    {
        static void Main(string[] args)
        {
            var client = new AmazonS3Client(RegionEndpoint.APNortheast1);
            var cancelToken = new CancellationToken();
            var task = client.ListBucketsAsync(cancelToken);
            var response = task.Result;
            foreach (var bucket in response.Buckets)
            {
                Console.WriteLine("{0}", bucket.BucketName);
            }
        }
    }
}

~/.aws/config を設定してあれば、これを実行するとバケット一覧を表示します。

~/.aws/config に複数のプロファイルがあるならば AWS_PROFILE という環境変数でプロファイル名を指定できます。

$ AWS_PROFILE=foo dotnet run 

エラーメッセージの例

~/.aws/configcredential_source = Ec2InstanceMetadata と書くことで、EC2インスタンスにアタッチされているIAMロールを使うようにしている場合は、以下のようなエラーになります。

Unhandled exception. System.AggregateException: One or more errors occurred. (Assembly AWSSDK.SecurityToken could not be found or loaded. This assembly must be available at runtime to use Amazon.Runtime.AssumeRoleAWSCredentials.)

以下のように追加のパッケージをインストールするとエラー解消します。

$ dotnet add package AWSSDK.SecurityToken

Google Cloud DataflowをJavaで動かしてみる

Google Cloud Dataflowを試してみるべく、GCP公式サイトにあるチュートリアルをやりました。

Java と Apache Maven を使用したクイックスタート  |  Cloud Dataflow  |  Google Cloud

このチュートリアルでダウンロードできるJavaソースコードはファイル数が多く、コードも長いです。そこで、処理内容はできるだけそのままで、必要な要素を理解しやすいようにコードを短くしたサンプルを作成しました。

前提

ソースコード

pom.xmlJavaソースの2ファイルです。

├── pom.xml
└── src
    └── main
        └── java
            └── org
                └── example
                    └── MySample.java

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>org.example</groupId>
  <artifactId>mysample</artifactId>
  <version>0.1</version>
  <packaging>jar</packaging>
  <properties>
    <beam.version>2.20.0</beam.version>
    <maven-compiler-plugin.version>3.7.0</maven-compiler-plugin.version>
    <slf4j.version>1.7.25</slf4j.version>
  </properties>

  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-compiler-plugin</artifactId>
        <version>${maven-compiler-plugin.version}</version>
        <configuration>
          <source>1.8</source>
          <target>1.8</target>
        </configuration>
      </plugin>
    </plugins>
  </build>

  <profiles>
    <profile>
      <id>dataflow-runner</id>
      <activation><activeByDefault>true</activeByDefault></activation>
      <dependencies>
        <dependency>
          <groupId>org.apache.beam</groupId>
          <artifactId>beam-runners-google-cloud-dataflow-java</artifactId>
          <version>${beam.version}</version>
          <scope>runtime</scope>
        </dependency>
      </dependencies>
    </profile>
  </profiles>

  <dependencies>
    <dependency>
      <groupId>org.apache.beam</groupId>
      <artifactId>beam-sdks-java-core</artifactId>
      <version>${beam.version}</version>
    </dependency>
    <dependency>
      <groupId>org.slf4j</groupId>
      <artifactId>slf4j-api</artifactId>
      <version>${slf4j.version}</version>
    </dependency>
    <dependency>
      <groupId>org.slf4j</groupId>
      <artifactId>slf4j-jdk14</artifactId>
      <version>${slf4j.version}</version>
      <scope>runtime</scope>
    </dependency>
  </dependencies>
</project>

MySample.java

package org.example;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.Validation.Required;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;

public class MySample {

  static class ExtractWordsFn extends DoFn<String, String> {
    @ProcessElement
    public void processElement(@Element String element, OutputReceiver<String> receiver) {
      String[] words = element.split(" ", -1);
      for (String word : words) {
        if (!word.isEmpty()) {
          receiver.output(word);
        }
      }
    }
  }

  static class CountWords
      extends PTransform<PCollection<String>, PCollection<KV<String, Long>>> {
    @Override
    public PCollection<KV<String, Long>> expand(PCollection<String> lines) {
      PCollection<String> words = lines.apply(ParDo.of(new ExtractWordsFn()));
      PCollection<KV<String, Long>> wordCounts = words.apply(Count.perElement());
      return wordCounts;
    }
  }

  static class FormatAsTextFn extends SimpleFunction<KV<String, Long>, String> {
    @Override
    public String apply(KV<String, Long> input) {
      return input.getKey() + ": " + input.getValue();
    }
  }

  public interface WordCountOptions extends PipelineOptions {
    @Description("Path of the file to read from")
    @Required
    String getInputPath();
    void setInputPath(String value);

    @Description("Path of the file to write to")
    @Required
    String getOutputPath();
    void setOutputPath(String value);
  }

  static void runWordCount(WordCountOptions options) {
    Pipeline p = Pipeline.create(options);

    String inputPath = options.getInputPath();
    String outputPath = options.getOutputPath();

    p.apply("ReadLines", TextIO.read().from(inputPath)).
      apply(new CountWords()).
      apply(MapElements.via(new FormatAsTextFn())).
      apply("WriteCounts", TextIO.write().to(outputPath));

    p.run().waitUntilFinish();
  }

  public static void main(String[] args) {
    WordCountOptions options =
        PipelineOptionsFactory.fromArgs(args).withValidation().as(WordCountOptions.class);

    runWordCount(options);
  }
}

実行方法

$ mvn compile exec:java -Dexec.mainClass=org.example.MySample -Dexec.args="--project=PROJECT_ID --region=us-central1 --stagingLocation=STAGING_GS_PATH --inputPath=gs://apache-beam-samples/shakespeare/kinglear.txt --outputPath=OUTPUT_GS_PATH --runner=DataflowRunner"

上記コマンドの中、以下の3箇所は自分の環境に合わせて書き換えます。

  • PROJECT_ID: GCPのプロジェクトID、
  • STAGING_GS_PATH: Dataflowの実行に必要なjarファイルなどを置くGCSのパス。このパスの中に自動で必要なファイルがアップロードされる。例: gs://sample-bucket/dataflow/staging/
  • OUTPUT_GS_PATH: 出力先GCSパス。 例: gs://sample-bucket/dataflow/output と指定すると gs://sample-bucket/dataflow/output-00000-of-00003 のような名前のファイルが作成される

コマンドの中で入力として指定している gs://apache-beam-samples/shakespeare/kinglear.txt にはApacheが用意しているサンプル入力ファイルがあり、誰でも読み込み可能です。

このコマンドを実行すると、3分弱で完了します。GCPコンソールでDataflowのジョブを見ると次のように表示されています。

f:id:suzuki-navi:20201006224734p:plain

以上。

fluent-plugin-s3のhex_randomはリトライしても値が変わらない

fluentdでS3に出力させるにはS3のキーがユニークになるようにする必要がある。ランダム値を使い、重複時はリトライをさせればユニークになるだろうと思ったが、hex_randomはリトライをしても値が変わらないので、ダメだった、という話。

time_sliceの時間内に1つしかファイル出力がないのであれば、fluentdのS3出力設定に以下のようにtime_sliceを含めれば、キーがユニークになる。time_sliceが1時間単位の設定ならば、1時間に最大1ファイルということになる。

s3_object_key_format %{path}/%{time_slice}/data.%{file_extension}

time_sliceの時間内に複数のファイル出力を想定したいのであれば、以下のようにindexを付けると、キーがユニークになる。

s3_object_key_format %{path}/%{time_slice}/data.%{index}.%{file_extension}

indexは最初は0で、S3にすでに0が存在すれば1、S3に0も1も存在すれば2になる。fluentdはS3出力時にS3へのファイル存在チェックをして、空いている番号を都度探す。indexが1以上になるケースが少ないのであれば、これでもいい。

ただし、time_slice時間内に常に多くのファイル出力がある場合は、S3へのファイル存在チェックのためS3 APIコール数が膨大になってしまう。以下の記事が参考になる。

fluentdでS3に転送しているならば、無駄なGETリクエストに気をつけろ! - Qiita

そこで、以下のように書いてみた。ランダム値があればキーの重複がめったに起こらないし、重複になればリトライしてもらえばよい、と思っていた。

s3_object_key_format %{path}/%{time_slice}/%{hex_random}.%{file_extension}

しかし、なんとhex_randomはリトライしても値が変わらないようになっているらしい。以下の記事で解説されている。

fluent-plugin-s3 の hex_random プレースホルダー - Qiita

これにより、実際に重複が起きてfluentdがまれにエラーを吐いてしまっていた。

結局、hex_randomindexを組み合わせるのがいいのだろう、という結論。

s3_object_key_format %{path}/%{time_slice}/%{hex_random}.%{index}.%{file_extension}

リンク

関連する私の記事

AtCoder参戦日記 ABC177 2020/08/29 #1 ― 初参加

8月終わりごろからAtCoderに参加しています。参加の記録を残しておこうと思い、この記事は初参加となった2020/08/29のABC177の記録です。

参加当日は、ペナルティのルールをよくわかっていなくて、C問題で適当に提出しすぎて、大量のペナルティで損をしました。

D問題はUnion-Findを知らなくて解けず、Cまでの3問解けたのみでした。

問題 結果 言語
A Scala
B Scala
C Java
D Java
E
F

A - Don't be late

問題

パッと書ける言語としてScalaで書きましたが、Pythonのほうが早かったかも。

import java.util.Scanner;

object Main extends App {
  val sc = new Scanner(System.in);
  val d, t, s = sc.nextInt();

  if (d <= s * t) {
    println("Yes");
  } else {
    println("No");
  }
}

B - Substring

問題

パッと書ける言語としてScalaで書きました。

import java.util.Scanner;

object Main extends App {
  val sc = new Scanner(System.in);
  val s, t = sc.next();

  val answer = (0 to (s.length - t.length)).map { d =>
    t.length - (0 until t.length).count { i =>
      s(i + d) == t(i)
    }
  }.min;

  println(answer);
}

C - Sum of product of pairs

問題

最初バカ正直な処理に時間のかかる方法で書いてしまい、またミスも繰り返してペナルティをたくさん受け取ってしまいました。

ScalaではちょっとめんどくさそうだったのでJavaになりました。

import java.util.Scanner;

class Main {
  public static void main(String[] args) {

    var m = 1000000007;

    var sc = new Scanner(System.in);

    var n = sc.nextInt();
    var a = new int[n];
    for (int i = 0; i < n; i++) {
      a[i] = sc.nextInt() % m;
    }

    long sum1 = 0;
    long sum2 = 0;
    for (int i = 0; i < n; i++) {
      sum1 += a[i];
      sum2 += (long)a[i] * a[i] % m;
    }
    sum1 = sum1 % m;
    sum2 = sum2 % m;

    long t = (sum1 * sum1 - sum2) % m;
    long answer;
    if (t % 2 == 0) {
      answer = t / 2;
    } else {
      answer = (t + m) / 2;
    }

    System.out.println(answer);
  }
}

D - Friends

問題

Union-Findというアルゴリズムを使います。

コンテスト中はUnion-Findを知らなくて、結局時間内に解けませんでした。以下のコードはコンテスト終了後に調べて書いたものです。

3問目(C)と同じくScalaではめんどくさそうだったのでJavaになりました。

class Main {
  public static int root(int[] root, int i) {
    int r = root[i];
    if (r == i) {
      return r;
    } else {
      int r2 = root(root, r);
      root[i] = r2;
      return r2;
    }
  }

  public static void main(String[] args) {

    var sc = new Scanner(System.in);

    int n = sc.nextInt();
    int m = sc.nextInt();

    var ab = new int[2 * m];
    for (int i = 0; i < m; i++) {
      int a = sc.nextInt() - 1;
      int b = sc.nextInt() - 1;
      if (a > b) {
        int t = b;
        b = a;
        a = t;
      }
      ab[2 * i    ] = a;
      ab[2 * i + 1] = b;
    }

    var count = new int[n];
    var root = new int[n];
    for (int i = 0; i < n; i++) {
      count[i] = 1;
      root[i] = i;
    }
    for (int i = 0; i < m; i++) {
      int a = ab[2 * i    ];
      int b = ab[2 * i + 1];
      int ra = root(root, a);
      int rb = root(root, b);
      if (ra != rb) {
        if (ra > rb) {
          int t = rb;
          rb = ra;
          ra = t;
        }
        root[rb] = ra;
        count[ra] += count[rb];
      }
    }

    int max = 0;
    for (int i = 0; i < n; i++) {
      int c = count[i];
      if (c > max) {
        max = c;
      }
    }

    System.out.println(max);

  }
}

MNISTのニューラルネットワークでの学習の教師データ数と精度の関係

MNISTには全部で6万個の訓練用データがありますが、単純なニューラルネットワークで学習させたときの訓練用データの数と学習結果の精度の関係を見てみました。

size process_time train_loss train_accuracy test_loss test_accuracy
60000 76.70 0.0051 0.9988 0.1270 0.9752
15000 23.22 0.0038 0.9999 0.1766 0.9609
3750 10.67 0.0143 0.9997 0.2706 0.9268
950 7.61 0.0296 1.0000 0.4152 0.8790
250 7.28 0.0661 0.9960 0.7151 0.7821

左から、訓練用データサイズ、処理時間(秒)、訓練データでの損失関数値、訓練データでの精度、テストデータでの損失関数値、テストデータでの精度です。

学習推移のグラフ

size=60000

f:id:suzuki-navi:20200928193609p:plain

size=15000

f:id:suzuki-navi:20200928193630p:plain

size=3750

f:id:suzuki-navi:20200928193654p:plain

size=950

f:id:suzuki-navi:20200928193710p:plain

size=250

f:id:suzuki-navi:20200928193723p:plain

Pythonコード

Google Colaboratoryで実行しました。

import time
import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist

plt.rcParams['figure.figsize'] = (16.0, 7.0)

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 入力と出力サイズ
in_size = 28 * 28
out_size = 10

# モデル構造を定義
def createModel():
  hidden_size = 64
  model = tf.keras.models.Sequential()
  model.add(tf.keras.layers.Dense(hidden_size, activation='relu', input_shape=(in_size,)))
  model.add(tf.keras.layers.Dense(out_size, activation='softmax'))
  return model

# 学習の様子をグラフへ描画
def plotLearning(result):
  fig = plt.figure()

  xs = range(1, len(result.history['loss']) + 1)

  # ロスの推移をプロット
  ax1 = fig.add_subplot(1, 1, 1)
  ax1.set_ylim(0, 2.5)
  ax1.set_ylabel('Loss')
  ax1.plot(xs, result.history['loss'])
  ax1.plot(xs, result.history['val_loss'])

  # 正解率の推移をプロット
  ax2 = ax1.twinx()
  ax2.set_ylim(0.75, 1.0)
  ax2.set_ylabel('Accuracy')
  ax2.plot(xs, result.history['accuracy'])
  ax2.plot(xs, result.history['val_accuracy'])

  plt.title('Loss & Accuracy')
  plt.legend(['train', 'test'], loc='upper left')
  plt.show()

def calc(train_size, epochs):
  model = createModel()

  # モデルを構築
  model.compile(
      loss = "categorical_crossentropy",
      optimizer = "adam",
      metrics=["accuracy"])

  start = time.time()
  print("train_size: %d, epochs: %d" % (train_size, epochs))
  x_train_reshape = x_train.reshape(-1, in_size).astype('float32') / 255
  x_test_reshape = x_test.reshape(-1, in_size).astype('float32') / 255
  y_train_onehot = tf.keras.backend.one_hot(y_train, out_size)
  y_test_onehot = tf.keras.backend.one_hot(y_test, out_size)

  x_train_reshape = x_train_reshape[:train_size]
  y_train_onehot = y_train_onehot[:train_size]

  # 学習を実行
  result = model.fit(x_train_reshape, y_train_onehot,
      batch_size=50,
      epochs=epochs,
      verbose=1,
      validation_data=(x_test_reshape, y_test_onehot))

  processTime = time.time() - start
  print("train_size: %d, epochs: %d, processTime: %f" % (train_size, epochs, processTime))
  plotLearning(result)

  return "| %d | %5.2f | %6.4f | %6.4f | %6.4f | %6.4f |" % (train_size, processTime,
    result.history['loss'][-1], result.history['accuracy'][-1],
    result.history['val_loss'][-1], result.history['val_accuracy'][-1])

table = []
table.append(calc(60000, 30))
table.append(calc(15000, 30))
table.append(calc(3750, 30))
table.append(calc(950, 30))
table.append(calc(250, 30))
for s in table:
  print(s)

2020/10/12 追記

ニューロン数と精度の関係も見てみました。