Mercurial > mercurial > hgweb_ai.cgi
diff ai.go @ 0:43e580fa4719
first commit.
author | pyon@macmini |
---|---|
date | Mon, 04 Sep 2017 21:40:33 +0900 |
parents | |
children | c32b619844ba |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/ai.go Mon Sep 04 21:40:33 2017 +0900 @@ -0,0 +1,370 @@ +package main + +import ( + "fmt" + "io/ioutil" + "math" + "math/rand" + "log" + "os" + "sort" + "strconv" + "strings" + "time" +) + +// パラメタ変数 +var plabel []string +var w1_11, w1_12, w1_21, w1_22 []float64 +var b1_1, b1_2 []float64 +var h1_11, h1_12, h1_21, h1_22 []float64 +var w2_11, w2_21 []float64 +var b2_1 []float64 +var s1, s2 float64 + +// パラメタ以外の変数 +const st int = 10 // とりあえず 10セット試す +const dd int = 7 // とりあえず 7日を予想 +const pp int = 2009 // とりあえず pp種類のパラメタ +var x1, x2 float64 +var y [st][dd][pp]float64 +var e [st][pp]float64 +var lstart [st]string // 学習データの開始日 +var lx1, lx2 [st]float64 +var lans [st][dd]float64 +var ps, sw time.Time + +// 評価ソート用 +type Es struct { + idx int + val float64 +} +var ea [pp]Es + +func init() { + ps = time.Now() + lstart[0] = "2007-04-02" + lstart[1] = "2007-10-01" + lstart[2] = "2008-04-01" + lstart[3] = "2008-10-01" + lstart[4] = "2009-04-01" + lstart[5] = "2009-10-01" + lstart[6] = "2010-04-01" + lstart[7] = "2010-10-01" + lstart[8] = "2011-04-01" + lstart[9] = "2011-10-03" +} + +// メイン +func main() { + + dfile := "ai_input.dat" + pfile := "param.dat" + + // データ入力 + if err := read_data( dfile ); err != nil { + fmt.Fprintf( os.Stderr, "%v\n", err ) + os.Exit( 1 ) + } + swatch( "data input done." ) + + // パラメタ入力 + if err := read_param( pfile ); err != nil { + fmt.Fprintf( os.Stderr, "%v\n", err ) + os.Exit( 1 ) + } + swatch( "param read done." ) + + // 計算 & 評価 + calc_eval() + swatch( "calc & eval done." ) + + // 初期値出力 + mfile := "data/s00.dat" + if err := save_tmp( mfile, true, false ); err != nil { + fmt.Fprintf( os.Stderr, "%v\n", err ) + log.Fatal( err ) + } + swatch( "mid-save done." ) + + // パラメタ書換えつつ探索 + for n := 0; n < 200; n++ { + update() + swatch( "update done." ) + + calc_eval() + swatch( "calc & eval done." ) + mfile = fmt.Sprintf( "data/m%02d.dat", n ) + if err := save_tmp( mfile, true, false ); err != nil { + fmt.Fprintf( os.Stderr, "%v\n", err ) + log.Fatal( err ) + } + swatch( "mid-save done." ) + } + + // 最終出力 + /* + ofile := "data/out.dat" + if err := save_result( ofile, 0, 0.0 ); err != nil { + fmt.Fprintf( os.Stderr, "%v\n", err ) + os.Exit( 1 ) + } + */ + swatch( "finish.." ) +} + + +// データの読込み +//#通番 年月日 始値 高値 安値 終値 3日平均 7日平均 30日平均 +//0001 2007-04-02 117.84 118.08 117.46 117.84 117.84 117.84 117.84 +func read_data( file string ) error { + data, err := ioutil.ReadFile( file ) + if err != nil { + return err + } + + // 学習データ + j, skip := 0, 0 + var lsidx [st]int // learning-set start index + var value []string + for i, line := range strings.Split( string( data ), "\n" ) { + if strings.HasPrefix( line, "#" ) { + skip++ + } else { + buf := strings.Fields( line ) + if len( buf ) == 0 { + skip++ + continue + } + value = append( value, buf[5] ) + if j < st && buf[1] == lstart[j] { + lsidx[j] = i - skip + j++ + } + } + } + + for i, _ := range lstart { + var f float64 + f, _ = strconv.ParseFloat( value[ lsidx[i] + 1], 64 ); lx1[i] = f + f, _ = strconv.ParseFloat( value[ lsidx[i] ], 64 ); lx2[i] = f + for j := 0; j < dd; j++ { + f, _ = strconv.ParseFloat( value[ lsidx[i] + j + 2 ], 64 ); lans[i][j] = f + } + } + + // 検証データ + + return nil +} + +// パラメタの読込み +func read_param( file string ) error { + data, err := ioutil.ReadFile( file ) + if err != nil { + return err + } + for _, line := range strings.Split( string( data ), "\n" ) { + if !strings.HasPrefix( line, "#" ) { + buf := strings.Fields( line ) + if len( buf ) == 0 { + continue + } + plabel = append( plabel, buf[0] ) + + var f float64 + f, _ = strconv.ParseFloat( buf[1], 64 ); w1_11 = append( w1_11, f * 0.1 ) + f, _ = strconv.ParseFloat( buf[2], 64 ); w1_12 = append( w1_12, f * 0.1 ) + f, _ = strconv.ParseFloat( buf[3], 64 ); w1_21 = append( w1_21, f * 0.1 ) + f, _ = strconv.ParseFloat( buf[4], 64 ); w1_22 = append( w1_22, f * 0.1 ) + + f, _ = strconv.ParseFloat( buf[5], 64 ); b1_1 = append( b1_1, f * 0.1 ) + f, _ = strconv.ParseFloat( buf[6], 64 ); b1_2 = append( b1_2, f * 0.1 ) + + f, _ = strconv.ParseFloat( buf[7], 64 ); h1_11 = append( h1_11, f * 0.1 ) + f, _ = strconv.ParseFloat( buf[8], 64 ); h1_12 = append( h1_12, f * 0.1 ) + f, _ = strconv.ParseFloat( buf[9], 64 ); h1_21 = append( h1_21, f * 0.1 ) + f, _ = strconv.ParseFloat( buf[10], 64 ); h1_22 = append( h1_22, f * 0.1 ) + + f, _ = strconv.ParseFloat( buf[11], 64 ); w2_11 = append( w2_11, f * 0.1 ) + f, _ = strconv.ParseFloat( buf[12], 64 ); w2_21 = append( w2_21, f * 0.1 ) + + f, _ = strconv.ParseFloat( buf[13], 64 ); b2_1 = append( b2_1, f * 0.1 ) + } + } + return nil +} + +// 結果を保存 +func save_tmp( file string, save, detail bool ) error { + if !save { + return nil + } + + os.Remove( file ) + f, err := os.OpenFile( file, os.O_CREATE|os.O_WRONLY, 0644 ) + if err != nil { + return err + } + if detail { + for i := 0; i < st; i++ { + for p := 0; p < pp; p++ { + s := fmt.Sprintf( "e = %10.2f / %2d:%s\n", e[i][p], i, plabel[p] ) + f.WriteString( s ) + for d := 0; d < dd; d++ { + s = fmt.Sprintf( "%10.2f ( %.2f )\n", y[i][d][p], lans[i][d] ) + f.WriteString( s ) + } + } + f.WriteString( "----\n" ); + } + } + for p := 0; p < pp; p++ { + s := fmt.Sprintf( "ea = %7.2f / EA:%s [ ", ea[p].val, plabel[p] ) + s += fmt.Sprintf( "%3.1f %3.1f %3.1f %3.1f ", w1_11[p], w1_12[p], w1_21[p], w1_22[p] ) + s += fmt.Sprintf( "%3.1f %3.1f ", b1_1[p], b1_2[p] ) + s += fmt.Sprintf( "%3.1f %3.1f %3.1f %3.1f ", h1_11[p], h1_12[p], h1_21[p], h1_22[p] ) + s += fmt.Sprintf( "%3.1f %3.1f ", w2_11[p], w2_21[p] ) + s += fmt.Sprintf( "%3.1f ]\n", b2_1[p] ) + f.WriteString( s ) + } + if err := f.Close(); err != nil { + return err + } + return nil +} + +func save_result( file string, line int, result float64 ) error { + return nil +} + +// 計算と評価 +func calc_eval() { + for p := 0; p < pp; p++ { + ea[p].idx = p + ea[p].val = 0.0 + } + for i := 0; i < st; i++ { + for p := 0; p < pp; p++ { + initialize( i ) + for d := 0; d < dd; d++ { + y[i][d][p] = do_calc( p ) + } + e[i][p] = evaluate( i, p ) + ea[p].val += e[i][p] + } + } + for p := 0; p < pp; p++ { + if ea[p].val > 10000 { + ea[p].val = 9999.99 + } + } +} + +// アルゴリズム +func initialize( i int ) { + x1, x2 = lx1[i], lx2[i] + s1, s2 = 0.0, 0.0 +} + +func do_calc( p int ) float64 { + + a1 := x1 * w1_11[p] + x2 * w1_21[p] + s1 * h1_11[p] + s2 * h1_21[p] + b1_1[p] + a2 := x1 * w1_12[p] + x2 * w1_22[p] + s1 * h1_12[p] + s2 * h1_22[p] + b1_2[p] + + // ReLU + if a1 < 0 { + s1 = 0 + } else { + s1 = a1 + } + if a2 < 0 { + s2 = 0 + } else { + s2 = a2 + } + + y1 := a1 * w2_11[p] + a2 * w2_21[p] + b2_1[p] + + x2 = x1 + x1 = y1 + + return y1 +} + +// 評価関数 +func evaluate( i, p int ) float64 { + var e float64 + for d := 0; d < dd; d++ { + e += ( lans[i][d] - y[i][d][p] ) * ( lans[i][d] - y[i][d][p] ) + } + e /= float64(dd) + return e +} + +// パラメタ更新 +func update() { + // backup + var bplabel []string + var bw1_11, bw1_12, bw1_21, bw1_22 []float64 + var bb1_1, bb1_2 []float64 + var bh1_11, bh1_12, bh1_21, bh1_22 []float64 + var bw2_11, bw2_21 []float64 + var bb2_1 []float64 + + es := ea[:] + sort.SliceStable( es, func( i, j int ) bool { return es[i].val < es[j].val } ) + for i := 0; i < 7; i++ { + j := es[i].idx + bplabel = append( bplabel, plabel[j] ) + + bw1_11 = append( bw1_11, w1_11[j] ) + bw1_12 = append( bw1_12, w1_12[j] ) + bw1_21 = append( bw1_21, w1_21[j] ) + bw1_22 = append( bw1_22, w1_22[j] ) + + bb1_1 = append( bb1_1, b1_1[j] ) + bb1_2 = append( bb1_2, b1_2[j] ) + + bh1_11 = append( bh1_11, h1_11[j] ) + bh1_12 = append( bh1_12, h1_12[j] ) + bh1_21 = append( bh1_21, h1_21[j] ) + bh1_22 = append( bh1_22, h1_22[j] ) + + bw2_11 = append( bw2_11, w2_11[j] ) + bw2_21 = append( bw2_21, w2_21[j] ) + + bb2_1 = append( bb2_1, b2_1[j] ) + } + + rand.Seed( time.Now().UnixNano() ) + for p := 0; p < pp; p++ { + plabel[p] = fmt.Sprintf( "%04d", p ) + i := int( math.Mod( float64( p ), 7 ) ) + + w1_11[p] = bw1_11[i] + float64( rand.Intn( 3 ) - 1 ) + w1_12[p] = bw1_12[i] + float64( rand.Intn( 3 ) - 1 ) + w1_21[p] = bw1_21[i] + float64( rand.Intn( 3 ) - 1 ) + w1_22[p] = bw1_22[i] + float64( rand.Intn( 3 ) - 1 ) + + b1_1[p] = bb1_1[i] + float64( rand.Intn( 3 ) - 1 ) + b1_2[p] = bb1_2[i] + float64( rand.Intn( 3 ) - 1 ) + + h1_11[p] = bh1_11[i] + float64( rand.Intn( 3 ) - 1 ) + h1_12[p] = bh1_12[i] + float64( rand.Intn( 3 ) - 1 ) + h1_21[p] = bh1_21[i] + float64( rand.Intn( 3 ) - 1 ) + h1_22[p] = bh1_22[i] + float64( rand.Intn( 3 ) - 1 ) + + w2_11[p] = bw2_11[i] + float64( rand.Intn( 3 ) - 1 ) + w2_21[p] = bw2_21[i] + float64( rand.Intn( 3 ) - 1 ) + + b2_1[p] = bb2_1[i] + float64( rand.Intn( 3 ) - 1 ) + } +} + +// 時間計測 +func swatch( s string ) { + sw = time.Now() + fmt.Fprintf( os.Stderr, "[ %v ( %v ) ]\t%s.\n", time.Since( ps ), time.Since( sw ), s ) +} +