lu decomp implementation.
This commit is contained in:
parent
e4f05d3863
commit
5150caf515
2 changed files with 75 additions and 0 deletions
|
@ -324,8 +324,68 @@ func (m *Matrix) Multiply(m2 *Matrix) (m3 *Matrix, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// Pivotize does something i'm not sure what.
|
||||
func (m *Matrix) Pivotize() *Matrix {
|
||||
pv := make([]int, m.stride)
|
||||
|
||||
for i := range pv {
|
||||
pv[i] = i
|
||||
}
|
||||
|
||||
for j, dx := 0, 0; j < m.stride; j++ {
|
||||
row := j
|
||||
max := m.elements[dx]
|
||||
for i, ixcj := j, dx; i < m.stride; i++ {
|
||||
if m.elements[ixcj] > max {
|
||||
max = m.elements[ixcj]
|
||||
row = i
|
||||
}
|
||||
ixcj += m.stride
|
||||
}
|
||||
if j != row {
|
||||
pv[row], pv[j] = pv[j], pv[row]
|
||||
}
|
||||
dx += m.stride + 1
|
||||
}
|
||||
p := Zero(m.stride, m.stride)
|
||||
for r, c := range pv {
|
||||
p.elements[r*m.stride+c] = 1
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Decompositions
|
||||
|
||||
// LU performs the LU decomposition.
|
||||
func (m *Matrix) LU() (l, u, p *Matrix) {
|
||||
l = Zero(m.stride, m.stride)
|
||||
u = Zero(m.stride, m.stride)
|
||||
p = m.Pivotize()
|
||||
m, _ = p.Multiply(m)
|
||||
for j, jxc0 := 0, 0; j < m.stride; j++ {
|
||||
l.elements[jxc0+j] = 1
|
||||
for i, ixc0 := 0, 0; ixc0 <= jxc0; i++ {
|
||||
sum := 0.
|
||||
for k, kxcj := 0, j; k < i; k++ {
|
||||
sum += u.elements[kxcj] * l.elements[ixc0+k]
|
||||
kxcj += m.stride
|
||||
}
|
||||
u.elements[ixc0+j] = m.elements[ixc0+j] - sum
|
||||
ixc0 += m.stride
|
||||
}
|
||||
for ixc0 := jxc0; ixc0 < len(m.elements); ixc0 += m.stride {
|
||||
sum := 0.
|
||||
for k, kxcj := 0, j; k < j; k++ {
|
||||
sum += u.elements[kxcj] * l.elements[ixc0+k]
|
||||
kxcj += m.stride
|
||||
}
|
||||
l.elements[ixc0+j] = (m.elements[ixc0+j] - sum) / u.elements[jxc0+j]
|
||||
}
|
||||
jxc0 += m.stride
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// QR performs the qr decomposition.
|
||||
func (m *Matrix) QR() (q, r *Matrix) {
|
||||
defer func() {
|
||||
|
|
|
@ -329,6 +329,21 @@ func TestMatrixString(t *testing.T) {
|
|||
assert.Equal("1 2 3 \n4 5 6 \n7 8 9 \n", m.String())
|
||||
}
|
||||
|
||||
func TestMatrixLU(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
m := NewFromArrays([][]float64{
|
||||
{1, 3, 5},
|
||||
{2, 4, 7},
|
||||
{1, 1, 0},
|
||||
})
|
||||
|
||||
l, u, p := m.LU()
|
||||
assert.NotNil(l)
|
||||
assert.NotNil(u)
|
||||
assert.NotNil(p)
|
||||
}
|
||||
|
||||
func TestMatrixQR(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
|
|
Loading…
Reference in a new issue