lu decomp implementation.

This commit is contained in:
Will Charczuk 2017-04-17 23:08:46 -07:00
parent e4f05d3863
commit 5150caf515
2 changed files with 75 additions and 0 deletions

View file

@ -324,8 +324,68 @@ func (m *Matrix) Multiply(m2 *Matrix) (m3 *Matrix, err error) {
return
}
// Pivotize does something i'm not sure what.
func (m *Matrix) Pivotize() *Matrix {
pv := make([]int, m.stride)
for i := range pv {
pv[i] = i
}
for j, dx := 0, 0; j < m.stride; j++ {
row := j
max := m.elements[dx]
for i, ixcj := j, dx; i < m.stride; i++ {
if m.elements[ixcj] > max {
max = m.elements[ixcj]
row = i
}
ixcj += m.stride
}
if j != row {
pv[row], pv[j] = pv[j], pv[row]
}
dx += m.stride + 1
}
p := Zero(m.stride, m.stride)
for r, c := range pv {
p.elements[r*m.stride+c] = 1
}
return p
}
// Decompositions
// LU performs the LU decomposition.
func (m *Matrix) LU() (l, u, p *Matrix) {
l = Zero(m.stride, m.stride)
u = Zero(m.stride, m.stride)
p = m.Pivotize()
m, _ = p.Multiply(m)
for j, jxc0 := 0, 0; j < m.stride; j++ {
l.elements[jxc0+j] = 1
for i, ixc0 := 0, 0; ixc0 <= jxc0; i++ {
sum := 0.
for k, kxcj := 0, j; k < i; k++ {
sum += u.elements[kxcj] * l.elements[ixc0+k]
kxcj += m.stride
}
u.elements[ixc0+j] = m.elements[ixc0+j] - sum
ixc0 += m.stride
}
for ixc0 := jxc0; ixc0 < len(m.elements); ixc0 += m.stride {
sum := 0.
for k, kxcj := 0, j; k < j; k++ {
sum += u.elements[kxcj] * l.elements[ixc0+k]
kxcj += m.stride
}
l.elements[ixc0+j] = (m.elements[ixc0+j] - sum) / u.elements[jxc0+j]
}
jxc0 += m.stride
}
return
}
// QR performs the qr decomposition.
func (m *Matrix) QR() (q, r *Matrix) {
defer func() {

View file

@ -329,6 +329,21 @@ func TestMatrixString(t *testing.T) {
assert.Equal("1 2 3 \n4 5 6 \n7 8 9 \n", m.String())
}
func TestMatrixLU(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 3, 5},
{2, 4, 7},
{1, 1, 0},
})
l, u, p := m.LU()
assert.NotNil(l)
assert.NotNil(u)
assert.NotNil(p)
}
func TestMatrixQR(t *testing.T) {
assert := assert.New(t)