diff ai.go @ 0:43e580fa4719

first commit.
author pyon@macmini
date Mon, 04 Sep 2017 21:40:33 +0900
parents
children c32b619844ba
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/ai.go	Mon Sep 04 21:40:33 2017 +0900
@@ -0,0 +1,370 @@
+package main
+
+import (
+	"fmt"
+	"io/ioutil"
+    "math"
+    "math/rand"
+    "log"
+	"os"
+    "sort"
+	"strconv"
+	"strings"
+    "time"
+)
+
+// パラメタ変数
+var plabel []string
+var w1_11, w1_12, w1_21, w1_22 []float64
+var b1_1, b1_2 []float64
+var h1_11, h1_12, h1_21, h1_22 []float64
+var w2_11, w2_21 []float64
+var b2_1 []float64
+var s1, s2 float64
+
+// パラメタ以外の変数
+const st int = 10 // とりあえず 10セット試す
+const dd int = 7 // とりあえず 7日を予想
+const pp int = 2009 // とりあえず pp種類のパラメタ
+var x1, x2 float64
+var y [st][dd][pp]float64
+var e [st][pp]float64
+var lstart [st]string   // 学習データの開始日
+var lx1, lx2 [st]float64
+var lans [st][dd]float64
+var ps, sw time.Time
+
+// 評価ソート用
+type Es struct {
+    idx int
+    val float64
+}
+var ea [pp]Es
+
+func init() {
+    ps = time.Now()
+    lstart[0] = "2007-04-02"
+    lstart[1] = "2007-10-01"
+    lstart[2] = "2008-04-01"
+    lstart[3] = "2008-10-01"
+    lstart[4] = "2009-04-01"
+    lstart[5] = "2009-10-01"
+    lstart[6] = "2010-04-01"
+    lstart[7] = "2010-10-01"
+    lstart[8] = "2011-04-01"
+    lstart[9] = "2011-10-03"
+}
+
+// メイン
+func main() {
+
+    dfile := "ai_input.dat"
+	pfile := "param.dat"
+
+    // データ入力
+	if err := read_data( dfile ); err != nil {
+		fmt.Fprintf( os.Stderr, "%v\n", err )
+		os.Exit( 1 )
+	}
+    swatch( "data input done." )
+
+    // パラメタ入力
+	if err := read_param( pfile ); err != nil {
+		fmt.Fprintf( os.Stderr, "%v\n", err )
+		os.Exit( 1 )
+	}
+    swatch( "param read done." )
+
+    // 計算 & 評価
+    calc_eval()
+    swatch( "calc & eval done." )
+
+    // 初期値出力
+	mfile := "data/s00.dat"
+	if err := save_tmp( mfile, true, false ); err != nil {
+		fmt.Fprintf( os.Stderr, "%v\n", err )
+		log.Fatal( err )
+	}
+    swatch( "mid-save done." )
+
+    // パラメタ書換えつつ探索
+    for n := 0; n < 200; n++ {
+        update()
+        swatch( "update done." )
+
+        calc_eval()
+        swatch( "calc & eval done." )
+        mfile = fmt.Sprintf( "data/m%02d.dat", n )
+        if err := save_tmp( mfile, true, false ); err != nil {
+            fmt.Fprintf( os.Stderr, "%v\n", err )
+            log.Fatal( err )
+        }
+        swatch( "mid-save done." )
+    }
+
+    // 最終出力
+    /*
+    ofile := "data/out.dat"
+	if err := save_result( ofile, 0, 0.0 ); err != nil {
+		fmt.Fprintf( os.Stderr, "%v\n", err )
+		os.Exit( 1 )
+	}
+    */
+    swatch( "finish.." )
+}
+
+
+// データの読込み
+//#通番 年月日    始値   高値   安値   終値   3日平均 7日平均 30日平均
+//0001 2007-04-02 117.84 118.08 117.46 117.84 117.84 117.84 117.84
+func read_data( file string ) error {
+	data, err := ioutil.ReadFile( file )
+	if err != nil {
+		return err
+	}
+
+    // 学習データ
+    j, skip := 0, 0
+    var lsidx [st]int   // learning-set start index
+    var value []string
+	for i, line := range strings.Split( string( data ), "\n" ) {
+        if strings.HasPrefix( line, "#" ) {
+            skip++
+        } else {
+            buf := strings.Fields( line )
+            if len( buf ) == 0 {
+                skip++
+                continue
+            }
+            value = append( value, buf[5] )
+            if j < st && buf[1] == lstart[j] {
+                lsidx[j] = i - skip
+                j++
+            }
+        }
+	}
+
+    for i, _ := range lstart {
+        var f float64
+        f, _ = strconv.ParseFloat( value[ lsidx[i] + 1], 64 ); lx1[i] = f
+        f, _ = strconv.ParseFloat( value[ lsidx[i] ],    64 ); lx2[i] = f
+        for j := 0; j < dd; j++ {
+            f, _ = strconv.ParseFloat( value[ lsidx[i] + j + 2 ], 64 ); lans[i][j] = f
+        }
+    }
+
+    // 検証データ
+
+	return nil
+}
+
+// パラメタの読込み
+func read_param( file string ) error {
+	data, err := ioutil.ReadFile( file )
+	if err != nil {
+		return err
+	}
+	for _, line := range strings.Split( string( data ), "\n" ) {
+        if !strings.HasPrefix( line, "#" ) {
+            buf := strings.Fields( line )
+            if len( buf ) == 0 {
+                continue
+            }
+            plabel = append( plabel, buf[0] )
+
+            var f float64
+            f, _ = strconv.ParseFloat( buf[1], 64 ); w1_11 = append( w1_11, f * 0.1 )
+            f, _ = strconv.ParseFloat( buf[2], 64 ); w1_12 = append( w1_12, f * 0.1 )
+            f, _ = strconv.ParseFloat( buf[3], 64 ); w1_21 = append( w1_21, f * 0.1 )
+            f, _ = strconv.ParseFloat( buf[4], 64 ); w1_22 = append( w1_22, f * 0.1 )
+
+            f, _ = strconv.ParseFloat( buf[5], 64 ); b1_1 = append( b1_1, f * 0.1 )
+            f, _ = strconv.ParseFloat( buf[6], 64 ); b1_2 = append( b1_2, f * 0.1 )
+
+            f, _ = strconv.ParseFloat( buf[7],  64 ); h1_11 = append( h1_11, f * 0.1 )
+            f, _ = strconv.ParseFloat( buf[8],  64 ); h1_12 = append( h1_12, f * 0.1 )
+            f, _ = strconv.ParseFloat( buf[9],  64 ); h1_21 = append( h1_21, f * 0.1 )
+            f, _ = strconv.ParseFloat( buf[10], 64 ); h1_22 = append( h1_22, f * 0.1 )
+
+            f, _ = strconv.ParseFloat( buf[11], 64 ); w2_11 = append( w2_11, f * 0.1 )
+            f, _ = strconv.ParseFloat( buf[12], 64 ); w2_21 = append( w2_21, f * 0.1 )
+
+            f, _ = strconv.ParseFloat( buf[13], 64 ); b2_1 = append( b2_1, f * 0.1 )
+        }
+	}
+	return nil
+}
+
+// 結果を保存
+func save_tmp( file string, save, detail bool ) error {
+    if !save {
+        return nil
+    }
+
+    os.Remove( file )
+	f, err := os.OpenFile( file, os.O_CREATE|os.O_WRONLY, 0644 )
+	if err != nil {
+        return err
+	}
+    if detail {
+        for i := 0; i < st; i++ {
+            for p := 0; p < pp; p++ {
+                s := fmt.Sprintf( "e = %10.2f / %2d:%s\n", e[i][p], i, plabel[p] )
+                f.WriteString( s )
+                for d := 0; d < dd; d++ {
+                    s = fmt.Sprintf( "%10.2f ( %.2f )\n", y[i][d][p], lans[i][d] )
+                    f.WriteString( s )
+                }
+            }
+            f.WriteString( "----\n" );
+        }
+    }
+    for p := 0; p < pp; p++ {
+        s := fmt.Sprintf( "ea = %7.2f / EA:%s [ ", ea[p].val, plabel[p] )
+        s += fmt.Sprintf( "%3.1f %3.1f %3.1f %3.1f ", w1_11[p], w1_12[p], w1_21[p], w1_22[p] )
+        s += fmt.Sprintf( "%3.1f %3.1f ", b1_1[p], b1_2[p] )
+        s += fmt.Sprintf( "%3.1f %3.1f %3.1f %3.1f ", h1_11[p], h1_12[p], h1_21[p], h1_22[p] )
+        s += fmt.Sprintf( "%3.1f %3.1f ", w2_11[p], w2_21[p] )
+        s += fmt.Sprintf( "%3.1f ]\n", b2_1[p] )
+        f.WriteString( s )
+    }
+	if err := f.Close(); err != nil {
+        return err
+	}
+    return nil
+}
+
+func save_result( file string, line int, result float64 ) error {
+	return nil
+}
+
+// 計算と評価
+func calc_eval() {
+    for p := 0; p < pp; p++ {
+        ea[p].idx = p
+        ea[p].val = 0.0
+    }
+    for i := 0; i < st; i++ {
+        for p := 0; p < pp; p++ {
+            initialize( i )
+            for d := 0; d < dd; d++ {
+                y[i][d][p] = do_calc( p )
+            }
+            e[i][p] = evaluate( i, p )
+            ea[p].val += e[i][p]
+        }
+	}
+    for p := 0; p < pp; p++ {
+        if ea[p].val > 10000 {
+            ea[p].val = 9999.99
+        }
+    }
+}
+
+// アルゴリズム
+func initialize( i int ) {
+	x1, x2 = lx1[i], lx2[i]
+	s1, s2 = 0.0, 0.0
+}
+
+func do_calc( p int ) float64 {
+
+	a1 := x1 * w1_11[p] + x2 * w1_21[p] + s1 * h1_11[p] + s2 * h1_21[p] + b1_1[p]
+	a2 := x1 * w1_12[p] + x2 * w1_22[p] + s1 * h1_12[p] + s2 * h1_22[p] + b1_2[p]
+
+	// ReLU
+	if a1 < 0 {
+		s1 = 0
+	} else {
+		s1 = a1
+	}
+	if a2 < 0 {
+		s2 = 0
+	} else {
+		s2 = a2
+	}
+
+	y1 := a1 * w2_11[p] + a2 * w2_21[p] + b2_1[p]
+
+	x2 = x1
+	x1 = y1
+
+	return y1
+}
+
+// 評価関数
+func evaluate( i, p int ) float64 {
+    var e float64
+    for d := 0; d < dd; d++ {
+        e += ( lans[i][d] - y[i][d][p] ) * ( lans[i][d] - y[i][d][p] )
+    }
+    e /= float64(dd)
+    return e
+}
+
+// パラメタ更新
+func update() {
+    // backup
+    var bplabel []string
+    var bw1_11, bw1_12, bw1_21, bw1_22 []float64
+    var bb1_1, bb1_2 []float64
+    var bh1_11, bh1_12, bh1_21, bh1_22 []float64
+    var bw2_11, bw2_21 []float64
+    var bb2_1 []float64
+
+    es := ea[:]
+    sort.SliceStable( es, func( i, j int ) bool { return es[i].val < es[j].val } )
+    for i := 0; i < 7; i++ {
+        j := es[i].idx
+        bplabel = append( bplabel, plabel[j] )
+
+        bw1_11 = append( bw1_11, w1_11[j] )
+        bw1_12 = append( bw1_12, w1_12[j] )
+        bw1_21 = append( bw1_21, w1_21[j] )
+        bw1_22 = append( bw1_22, w1_22[j] )
+
+        bb1_1 = append( bb1_1, b1_1[j] )
+        bb1_2 = append( bb1_2, b1_2[j] )
+
+        bh1_11 = append( bh1_11, h1_11[j] )
+        bh1_12 = append( bh1_12, h1_12[j] )
+        bh1_21 = append( bh1_21, h1_21[j] )
+        bh1_22 = append( bh1_22, h1_22[j] )
+
+        bw2_11 = append( bw2_11, w2_11[j] )
+        bw2_21 = append( bw2_21, w2_21[j] )
+
+        bb2_1 = append( bb2_1, b2_1[j] )
+    }
+
+	rand.Seed( time.Now().UnixNano() )
+    for p := 0; p < pp; p++ {
+        plabel[p] = fmt.Sprintf( "%04d", p )
+        i := int( math.Mod( float64( p ), 7 ) )
+
+        w1_11[p] = bw1_11[i] + float64( rand.Intn( 3 ) - 1 )
+        w1_12[p] = bw1_12[i] + float64( rand.Intn( 3 ) - 1 )
+        w1_21[p] = bw1_21[i] + float64( rand.Intn( 3 ) - 1 )
+        w1_22[p] = bw1_22[i] + float64( rand.Intn( 3 ) - 1 )
+
+        b1_1[p] = bb1_1[i] + float64( rand.Intn( 3 ) - 1 )
+        b1_2[p] = bb1_2[i] + float64( rand.Intn( 3 ) - 1 )
+
+        h1_11[p] = bh1_11[i] + float64( rand.Intn( 3 ) - 1 )
+        h1_12[p] = bh1_12[i] + float64( rand.Intn( 3 ) - 1 )
+        h1_21[p] = bh1_21[i] + float64( rand.Intn( 3 ) - 1 )
+        h1_22[p] = bh1_22[i] + float64( rand.Intn( 3 ) - 1 )
+
+        w2_11[p] = bw2_11[i] + float64( rand.Intn( 3 ) - 1 )
+        w2_21[p] = bw2_21[i] + float64( rand.Intn( 3 ) - 1 )
+
+        b2_1[p] = bb2_1[i] + float64( rand.Intn( 3 ) - 1 )
+    }
+}
+
+// 時間計測
+func swatch( s string ) {
+    sw = time.Now()
+    fmt.Fprintf( os.Stderr, "[ %v ( %v ) ]\t%s.\n", time.Since( ps ), time.Since( sw ), s )
+}
+