From 5150caf5155452c1f858755f7bf925f1ea5fd474 Mon Sep 17 00:00:00 2001 From: Will Charczuk Date: Mon, 17 Apr 2017 23:08:46 -0700 Subject: [PATCH] lu decomp implementation. --- matrix/matrix.go | 60 +++++++++++++++++++++++++++++++++++++++++++ matrix/matrix_test.go | 15 +++++++++++ 2 files changed, 75 insertions(+) diff --git a/matrix/matrix.go b/matrix/matrix.go index 51a3da0..c94768b 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -324,8 +324,68 @@ func (m *Matrix) Multiply(m2 *Matrix) (m3 *Matrix, err error) { return } +// Pivotize does something i'm not sure what. +func (m *Matrix) Pivotize() *Matrix { + pv := make([]int, m.stride) + + for i := range pv { + pv[i] = i + } + + for j, dx := 0, 0; j < m.stride; j++ { + row := j + max := m.elements[dx] + for i, ixcj := j, dx; i < m.stride; i++ { + if m.elements[ixcj] > max { + max = m.elements[ixcj] + row = i + } + ixcj += m.stride + } + if j != row { + pv[row], pv[j] = pv[j], pv[row] + } + dx += m.stride + 1 + } + p := Zero(m.stride, m.stride) + for r, c := range pv { + p.elements[r*m.stride+c] = 1 + } + return p +} + // Decompositions +// LU performs the LU decomposition. +func (m *Matrix) LU() (l, u, p *Matrix) { + l = Zero(m.stride, m.stride) + u = Zero(m.stride, m.stride) + p = m.Pivotize() + m, _ = p.Multiply(m) + for j, jxc0 := 0, 0; j < m.stride; j++ { + l.elements[jxc0+j] = 1 + for i, ixc0 := 0, 0; ixc0 <= jxc0; i++ { + sum := 0. + for k, kxcj := 0, j; k < i; k++ { + sum += u.elements[kxcj] * l.elements[ixc0+k] + kxcj += m.stride + } + u.elements[ixc0+j] = m.elements[ixc0+j] - sum + ixc0 += m.stride + } + for ixc0 := jxc0; ixc0 < len(m.elements); ixc0 += m.stride { + sum := 0. + for k, kxcj := 0, j; k < j; k++ { + sum += u.elements[kxcj] * l.elements[ixc0+k] + kxcj += m.stride + } + l.elements[ixc0+j] = (m.elements[ixc0+j] - sum) / u.elements[jxc0+j] + } + jxc0 += m.stride + } + return +} + // QR performs the qr decomposition. func (m *Matrix) QR() (q, r *Matrix) { defer func() { diff --git a/matrix/matrix_test.go b/matrix/matrix_test.go index c58c5ba..52475f2 100644 --- a/matrix/matrix_test.go +++ b/matrix/matrix_test.go @@ -329,6 +329,21 @@ func TestMatrixString(t *testing.T) { assert.Equal("1 2 3 \n4 5 6 \n7 8 9 \n", m.String()) } +func TestMatrixLU(t *testing.T) { + assert := assert.New(t) + + m := NewFromArrays([][]float64{ + {1, 3, 5}, + {2, 4, 7}, + {1, 1, 0}, + }) + + l, u, p := m.LU() + assert.NotNil(l) + assert.NotNil(u) + assert.NotNil(p) +} + func TestMatrixQR(t *testing.T) { assert := assert.New(t)